[TPU] Reduce compilation time & Upgrade PyTorch XLA version (#6856)

This commit is contained in:
Woosuk Kwon
2024-07-27 10:28:33 -07:00
committed by GitHub
parent f954d0715c
commit fad5576c58
6 changed files with 24 additions and 7 deletions

View File

@ -1,4 +1,4 @@
ARG NIGHTLY_DATE="20240713"
ARG NIGHTLY_DATE="20240726"
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
FROM $BASE_IMAGE

View File

@ -56,7 +56,7 @@ First, install the dependencies:
$ pip uninstall torch torch-xla -y
$ # Install PyTorch and PyTorch XLA.
$ export DATE="+20240713"
$ export DATE="+20240726"
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl
@ -75,6 +75,13 @@ Next, build vLLM from source. This will only take a few seconds:
$ VLLM_TARGET_DEVICE="tpu" python setup.py develop
.. note::
Since TPU relies on XLA which requires static shapes, vLLM bucketizes the possible input shapes and compiles an XLA graph for each different shape.
The compilation time may take 20~30 minutes in the first run.
However, the compilation time reduces to ~5 minutes afterwards because the XLA graphs are cached in the disk (in :code:`VLLM_XLA_CACHE_PATH` or :code:`~/.cache/vllm/xla_cache` by default).
.. tip::
If you encounter the following error:

View File

@ -3,7 +3,6 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import torch
import torch_xla.experimental.custom_kernel # Required to register custom ops.
import torch_xla.experimental.dynamo_set_buffer_donor
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)

View File

@ -6,6 +6,7 @@ from vllm.platforms import current_platform
if current_platform.is_tpu():
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla._internal import pjrt
@ -20,7 +21,7 @@ class TpuCommunicator:
local_rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
pjrt.initialize_multiprocess(local_rank, world_size)
xm._init_world_size_ordinal()
xr._init_world_size_ordinal()
def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
return xm.all_reduce(xm.REDUCE_SUM, x)

View File

@ -7,6 +7,7 @@ import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
@ -127,7 +128,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
# determine the order of concatenating the output tensors.
# As a workaround, we use the xm's rank assignment only when loading
# the embedding weights.
xm_tp_rank = xm.get_ordinal()
xm_tp_rank = xr.global_ordinal()
with patch(
"vllm.model_executor.layers.vocab_parallel_embedding."
"get_tensor_model_parallel_rank",
@ -146,7 +147,17 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
xm.wait_device_ops()
model = ModelWrapper(model)
self.model = torch.compile(model, backend="openxla", fullgraph=True)
# NOTE(woosuk): There are two stages of compilation: torch.compile and
# XLA compilation. Setting dynamic=True can reduce the torch.compile
# overhead by reusing the FX graph for different shapes.
# However, the XLA graph will still require static shapes and needs to
# be re-compiled for every different shapes. This overhead is inevitable
# in the first run, but can be skipped afterwards as we cache the XLA
# graphs in the disk (VLLM_XLA_CACHE_PATH).
self.model = torch.compile(model,
backend="openxla",
fullgraph=True,
dynamic=True)
def _dummy_run(
self,

View File

@ -3,7 +3,6 @@ from typing import List, Optional, Tuple, Union
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.dynamo_set_buffer_donor # noqa: F401
import torch_xla.runtime as xr
import vllm.envs as envs