[TPU] Reduce compilation time & Upgrade PyTorch XLA version (#6856)
@@ -1,4 +1,4 @@
-ARG NIGHTLY_DATE="20240713"
+ARG NIGHTLY_DATE="20240726"
 ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
 
 FROM $BASE_IMAGE
@@ -56,7 +56,7 @@ First, install the dependencies:
 $ pip uninstall torch torch-xla -y
 
 $ # Install PyTorch and PyTorch XLA.
-$ export DATE="+20240713"
+$ export DATE="+20240726"
 $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
 $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl
 
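As a quick, hedged sanity check (not part of the documented steps), the freshly installed nightly builds can be verified from Python; this assumes the same Python 3.10 environment the wheels were installed into:

# Minimal post-install check. Assumption: run in the environment that
# received the nightly wheels above.
import torch
import torch_xla

print("torch:", torch.__version__)        # expect a nightly/dev version string
print("torch_xla:", torch_xla.__version__)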
@@ -75,6 +75,13 @@ Next, build vLLM from source. This will only take a few seconds:
 $ VLLM_TARGET_DEVICE="tpu" python setup.py develop
 
 
+.. note::
+
+    Since TPU relies on XLA which requires static shapes, vLLM bucketizes the possible input shapes and compiles an XLA graph for each different shape.
+    The compilation time may take 20~30 minutes in the first run.
+    However, the compilation time reduces to ~5 minutes afterwards because the XLA graphs are cached in the disk (in :code:`VLLM_XLA_CACHE_PATH` or :code:`~/.cache/vllm/xla_cache` by default).
+
+
 .. tip::
 
     If you encounter the following error:
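The note added above explains why the first run compiles slowly: XLA needs static shapes, so vLLM pads requests into a fixed set of shape buckets and compiles one graph per bucket. The snippet below is a self-contained sketch of that idea only; the specific bucket rules (powers of two for batch size, multiples of 16 for sequence length) are illustrative assumptions, not vLLM's actual padding logic.

# Sketch of shape bucketization: map many raw (batch, seq_len) shapes onto a
# small set of padded buckets so XLA only compiles one graph per bucket.
# The bucket rules here are assumptions for illustration.

def _next_power_of_two(n: int, minimum: int = 8) -> int:
    bucket = minimum
    while bucket < n:
        bucket *= 2
    return bucket

def bucketize(batch_size: int, seq_len: int, seq_multiple: int = 16) -> tuple:
    padded_batch = _next_power_of_two(batch_size)
    padded_seq = ((seq_len + seq_multiple - 1) // seq_multiple) * seq_multiple
    return padded_batch, padded_seq

if __name__ == "__main__":
    # Many distinct request shapes collapse into a handful of compiled buckets.
    for bs, sl in [(1, 37), (3, 37), (5, 100), (9, 100)]:
        print((bs, sl), "->", bucketize(bs, sl))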
@@ -3,7 +3,6 @@ from typing import Any, Dict, List, Optional, Tuple, Type
 
 import torch
 import torch_xla.experimental.custom_kernel  # Required to register custom ops.
-import torch_xla.experimental.dynamo_set_buffer_donor
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
@@ -6,6 +6,7 @@ from vllm.platforms import current_platform
 
 if current_platform.is_tpu():
     import torch_xla.core.xla_model as xm
+    import torch_xla.runtime as xr
     from torch_xla._internal import pjrt
 
 
@@ -20,7 +21,7 @@ class TpuCommunicator:
         local_rank = dist.get_rank(group)
         world_size = dist.get_world_size(group)
         pjrt.initialize_multiprocess(local_rank, world_size)
-        xm._init_world_size_ordinal()
+        xr._init_world_size_ordinal()
 
     def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
         return xm.all_reduce(xm.REDUCE_SUM, x)
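For context, the communicator above only swaps the ordinal initialization from xm to xr; the collective itself still goes through xm.all_reduce. The following is a hedged, single-host sketch of that call, assuming a TPU VM with torch_xla installed; it is not vLLM's distributed setup.

# Sketch: sum a tensor with the same primitive TpuCommunicator.all_reduce uses.
# Assumes torch_xla is installed and an XLA (TPU) device is available.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.ones(4, device=device)
y = xm.all_reduce(xm.REDUCE_SUM, x)  # with one replica this is effectively an identity
xm.mark_step()                       # force the lazy XLA graph to execute
print(y.cpu())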
@@ -7,6 +7,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch_xla.core.xla_model as xm
+import torch_xla.runtime as xr
 
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
@@ -127,7 +128,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         # determine the order of concatenating the output tensors.
         # As a workaround, we use the xm's rank assignment only when loading
         # the embedding weights.
-        xm_tp_rank = xm.get_ordinal()
+        xm_tp_rank = xr.global_ordinal()
         with patch(
                 "vllm.model_executor.layers.vocab_parallel_embedding."
                 "get_tensor_model_parallel_rank",
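The hunk above keeps the existing workaround of patching the tensor-parallel rank getter while loading the embedding weights and only changes where the rank comes from (xr.global_ordinal() instead of xm.get_ordinal()). Below is a toy, hedged illustration of that patching pattern; the parallel_state namespace and rank values are hypothetical stand-ins, not vLLM's real module or API.

# Toy illustration of temporarily overriding a rank getter during weight loading.
import types
from unittest.mock import patch

# Hypothetical stand-in for a parallel-state module whose rank getter the
# runner wants to override only while loading sharded embedding weights.
parallel_state = types.SimpleNamespace(get_rank=lambda: 0)

def load_embedding_weights() -> int:
    # Anything executed here observes the patched rank.
    return parallel_state.get_rank()

runtime_rank = 3  # on a real TPU host this would come from xr.global_ordinal()
with patch.object(parallel_state, "get_rank", lambda: runtime_rank):
    print(load_embedding_weights())  # -> 3 (runtime-assigned rank)
print(load_embedding_weights())      # -> 0 (default rank restored)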
@@ -146,7 +147,17 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         xm.wait_device_ops()
 
         model = ModelWrapper(model)
-        self.model = torch.compile(model, backend="openxla", fullgraph=True)
+        # NOTE(woosuk): There are two stages of compilation: torch.compile and
+        # XLA compilation. Setting dynamic=True can reduce the torch.compile
+        # overhead by reusing the FX graph for different shapes.
+        # However, the XLA graph will still require static shapes and needs to
+        # be re-compiled for every different shapes. This overhead is inevitable
+        # in the first run, but can be skipped afterwards as we cache the XLA
+        # graphs in the disk (VLLM_XLA_CACHE_PATH).
+        self.model = torch.compile(model,
+                                   backend="openxla",
+                                   fullgraph=True,
+                                   dynamic=True)
 
     def _dummy_run(
         self,
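To make the NOTE in the hunk above concrete, here is a hedged sketch of compiling a toy module with the same flags (backend="openxla", fullgraph=True, dynamic=True); it assumes torch_xla and an XLA device are available and is not vLLM's ModelWrapper.

# Sketch: dynamic=True lets torch.compile reuse one FX graph across input
# shapes, while XLA still compiles (and caches) one executable per shape.
# Assumes torch_xla is installed and an XLA/TPU device is available.
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = nn.Linear(16, 16).to(device)
compiled = torch.compile(model, backend="openxla", fullgraph=True, dynamic=True)

for seq_len in (8, 16, 32):
    x = torch.randn(seq_len, 16, device=device)
    out = compiled(x)   # FX graph is reused; XLA recompiles per new shape
    xm.mark_step()
    print(seq_len, tuple(out.shape))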
@@ -3,7 +3,6 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 import torch_xla.core.xla_model as xm
-import torch_xla.experimental.dynamo_set_buffer_donor  # noqa: F401
 import torch_xla.runtime as xr
 
 import vllm.envs as envs