[TPU] Rename tpu_commons to tpu_inference (#26279)

Signed-off-by: Utkarsh Sharma <utksharma@google.com>
Co-authored-by: Utkarsh Sharma <utksharma@google.com>
Co-authored-by: Chengji Yao <chengjiyao@google.com>
Author: Utkarsh Sharma
Date:   2025-10-08 12:00:52 +05:30
Committed by: GitHub
Parent: 5e65d6b2ad
Commit: 335b28f7d1

6 changed files with 22 additions and 22 deletions


@@ -10,7 +10,7 @@ from torch.distributed import ProcessGroup
 from vllm.config import get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.platforms.tpu import USE_TPU_COMMONS
+from vllm.platforms.tpu import USE_TPU_INFERENCE

 from .base_device_communicator import DeviceCommunicatorBase
@@ -20,8 +20,8 @@ USE_RAY = parallel_config = (
 logger = init_logger(__name__)

-if not USE_TPU_COMMONS:
-    logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
+if not USE_TPU_INFERENCE:
+    logger.info("tpu_inference not found, using vLLM's TpuCommunicator")

 if current_platform.is_tpu():
     import torch_xla
     import torch_xla.core.xla_model as xm
@@ -100,9 +100,9 @@ class TpuCommunicator(DeviceCommunicatorBase):
         return xm.all_gather(input_, dim=dim)


-if USE_TPU_COMMONS:
-    from tpu_commons.distributed.device_communicators import (
-        TpuCommunicator as TpuCommonsCommunicator,
+if USE_TPU_INFERENCE:
+    from tpu_inference.distributed.device_communicators import (
+        TpuCommunicator as TpuInferenceCommunicator,
     )

-    TpuCommunicator = TpuCommonsCommunicator  # type: ignore
+    TpuCommunicator = TpuInferenceCommunicator  # type: ignore
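The hunk above is an instance of vLLM's optional-backend override idiom: a module-level class name is rebound when the accelerated package is importable, and every later import of that name picks up the replacement. A minimal self-contained sketch of the same idiom, using a hypothetical package named fast_backend (everything below is illustrative, not vLLM's actual API):

# Sketch of the conditional-override idiom; "fast_backend" is hypothetical.
try:
    from fast_backend import Communicator as _FastCommunicator
    HAVE_FAST_BACKEND = True
except ImportError:
    HAVE_FAST_BACKEND = False

class Communicator:
    """Fallback implementation used when the optional package is absent."""

    def all_gather(self, x, dim=-1):
        return x  # illustrative no-op

if HAVE_FAST_BACKEND:
    # Rebinding the module-level name means any later
    # `from this_module import Communicator` resolves to the optional
    # implementation; the `type: ignore` in the real diff silences the
    # type checker's redefinition complaint.
    Communicator = _FastCommunicator  # type: ignore[misc]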


@@ -223,9 +223,9 @@ class DefaultModelLoader(BaseModelLoader):
             )

         if current_platform.is_tpu():
-            from vllm.platforms.tpu import USE_TPU_COMMONS
+            from vllm.platforms.tpu import USE_TPU_INFERENCE

-            if not USE_TPU_COMMONS:
+            if not USE_TPU_INFERENCE:
                 # In PyTorch XLA, we should call `torch_xla.sync`
                 # frequently so that not too many ops are accumulated
                 # in the XLA program.
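The comment in this hunk refers to cutting the lazily built XLA graph during weight loading. A hedged sketch of that idea, assuming torch_xla's module-level sync() (the successor to xm.mark_step()); the loop shape and sync_every value are illustrative, not vLLM's actual loader code:

import torch_xla

def copy_weights_with_periodic_sync(params, weights, sync_every=16):
    """Copy weights onto an XLA device, flushing pending ops
    periodically so the accumulated XLA program stays small."""
    for i, (name, tensor) in enumerate(weights):
        params[name].data.copy_(tensor)
        if (i + 1) % sync_every == 0:
            torch_xla.sync()  # materialize the graph built so far
    torch_xla.sync()  # flush any trailing ops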


@@ -37,7 +37,7 @@ def tpu_platform_plugin() -> Optional[str]:
     # Check for Pathways TPU proxy
     if envs.VLLM_TPU_USING_PATHWAYS:
         logger.debug("Confirmed TPU platform is available via Pathways proxy.")
-        return "tpu_commons.platforms.tpu_jax.TpuPlatform"
+        return "tpu_inference.platforms.tpu_jax.TpuPlatform"

     # Check for libtpu installation
     try:
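The plugin hook returns a fully qualified class path as a string, which lets the platform package stay an optional dependency. Resolving such a string is a one-liner with importlib; a minimal sketch (vLLM's actual resolver may differ):

import importlib

def resolve_plugin(qualname: str):
    """Turn 'pkg.module.ClassName' into the class object it names."""
    module_name, _, class_name = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)

# e.g. resolve_plugin("tpu_inference.platforms.tpu_jax.TpuPlatform")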


@@ -26,7 +26,7 @@ else:

 logger = init_logger(__name__)

-USE_TPU_COMMONS = False
+USE_TPU_INFERENCE = False


 class TpuPlatform(Platform):
@@ -254,10 +254,10 @@ class TpuPlatform(Platform):

 try:
-    from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
+    from tpu_inference.platforms import TpuPlatform as TpuInferencePlatform

-    TpuPlatform = TpuCommonsPlatform  # type: ignore
-    USE_TPU_COMMONS = True
+    TpuPlatform = TpuInferencePlatform  # type: ignore
+    USE_TPU_INFERENCE = True
 except ImportError:
-    logger.info("tpu_commons not found, using vLLM's TpuPlatform")
+    logger.info("tpu_inference not found, using vLLM's TpuPlatform")
     pass
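Detection here piggybacks on an import the code needs anyway (the try block must also fetch TpuInferencePlatform), so try/except ImportError is the natural choice. If only the boolean flag were needed, an import-free probe would work too; a sketch:

import importlib.util

# Equivalent availability check that does not execute the package's
# import-time side effects.
USE_TPU_INFERENCE = importlib.util.find_spec("tpu_inference") is not None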


@@ -35,7 +35,7 @@ TPU_STR_DTYPE_TO_TORCH_DTYPE = {
 }

 try:
-    import tpu_commons  # noqa: F401
+    import tpu_inference  # noqa: F401
 except ImportError:
     # Lazy import torch_xla
     import torch_xla.core.xla_builder as xb
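The `# noqa: F401` marks the import as intentional: it exists purely to probe whether tpu_inference is installed, so flake8's unused-import warning is suppressed, and torch_xla stays lazily imported on the fallback path. A tiny generic sketch of the same probe idiom:

def backend_available(name: str) -> bool:
    """Return True if the named package can be imported."""
    try:
        __import__(name)  # probe only; the module object is unused
        return True
    except ImportError:
        return False

# backend_available("tpu_inference") -> True when the package is installed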


@@ -23,7 +23,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
-from vllm.platforms.tpu import USE_TPU_COMMONS
+from vllm.platforms.tpu import USE_TPU_INFERENCE
 from vllm.tasks import SupportedTask
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -36,8 +36,8 @@ logger = init_logger(__name__)
 _R = TypeVar("_R")

-if not USE_TPU_COMMONS:
-    logger.info("tpu_commons not found, using vLLM's TPUWorker.")
+if not USE_TPU_INFERENCE:
+    logger.info("tpu_inference not found, using vLLM's TPUWorker.")
     import torch_xla.core.xla_model as xm
     import torch_xla.debug.profiler as xp
     import torch_xla.runtime as xr
@@ -346,7 +346,7 @@ class TPUWorker:
         return fn(self.get_model())


-if USE_TPU_COMMONS:
-    from tpu_commons.worker import TPUWorker as TPUCommonsWorker
+if USE_TPU_INFERENCE:
+    from tpu_inference.worker import TPUWorker as TpuInferenceWorker

-    TPUWorker = TPUCommonsWorker  # type: ignore
+    TPUWorker = TpuInferenceWorker  # type: ignore
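After this rename, a quick way to confirm which implementation won is to inspect the rebound name's defining module. A sketch, assuming vLLM's in-tree TPU worker lives at vllm.v1.worker.tpu_worker (this diff does not show file paths, so the module path is an assumption):

# Module path below is an assumption; adjust if the layout differs.
from vllm.v1.worker.tpu_worker import TPUWorker

# Prints a tpu_inference module path when that package is installed,
# otherwise the path of vLLM's in-tree worker module.
print(TPUWorker.__module__)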