[TPU] Use mark_dynamic to reduce compilation time (#7340)
@@ -1,4 +1,4 @@
ARG NIGHTLY_DATE="20240726"
ARG NIGHTLY_DATE="20240808"
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"

FROM $BASE_IMAGE

@@ -56,7 +56,7 @@ First, install the dependencies:
$ pip uninstall torch torch-xla -y

$ # Install PyTorch and PyTorch XLA.
$ export DATE="+20240726"
$ export DATE="+20240808"
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl

@@ -65,7 +65,7 @@ First, install the dependencies:
$ pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

$ # Install other build dependencies.
$ pip install packaging aiohttp
$ pip install -r requirements-tpu.txt

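As a quick, optional sanity check of the installation steps above (a sketch, not taken from the vLLM docs): the snippet below assumes a TPU VM with the nightly wheels installed and simply verifies that torch and torch_xla import cleanly and expose an XLA device before you build vLLM.

# Optional sanity check (illustrative; not from the vLLM docs): verify that the
# nightly torch / torch_xla wheels import and can see an XLA device on the TPU VM.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()               # first XLA device (a TPU core on a TPU VM)
x = torch.randn(2, 2, device=device)
print(torch.__version__, x.device)     # expect a nightly version and an "xla:..." device
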
Next, build vLLM from source. This will only take a few seconds:
@@ -147,19 +147,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
        )
        model = model.eval()
        xm.wait_device_ops()

        model = ModelWrapper(model)
        # NOTE(woosuk): There are two stages of compilation: torch.compile and
        # XLA compilation. Setting dynamic=True can reduce the torch.compile
        # overhead by reusing the FX graph for different shapes.
        # However, the XLA graph will still require static shapes and needs to
        # be re-compiled for every different shape. This overhead is inevitable
        # in the first run, but can be skipped afterwards as we cache the XLA
        # graphs on disk (VLLM_XLA_CACHE_PATH).
        self.model = torch.compile(model,
                                   backend="openxla",
                                   fullgraph=True,
                                   dynamic=True)
        self.model = CompiledModelWrapper(model)

    def _dummy_run(
        self,
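For context on the NOTE in the removed code path above, here is a minimal sketch (not vLLM code) of what dynamic=True buys: Dynamo traces a single FX graph with symbolic shapes and reuses it as the input shape changes, so the torch.compile tracing overhead is paid roughly once. On the openxla backend each concrete shape would still be lowered and compiled by XLA, which is exactly the residual cost the commit targets.

# Minimal sketch (not vLLM code): torch.compile with dynamic=True reuses one
# FX graph across input shapes, avoiding repeated Dynamo tracing.  Runs on CPU
# with the default backend; with backend="openxla" each new shape would still
# trigger its own XLA compilation.
import torch

def f(x):
    return torch.nn.functional.relu(x) * 2

compiled = torch.compile(f, dynamic=True)
for seq_len in (16, 32, 64):
    y = compiled(torch.randn(1, seq_len))   # no re-trace for the new lengths
    print(seq_len, tuple(y.shape))
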
@@ -697,6 +685,52 @@ class ModelWrapper(nn.Module):
        return next_token_ids


class CompiledModelWrapper:

    def __init__(self, model: nn.Module):
        model = ModelWrapper(model)
        self.model = torch.compile(model,
                                   backend="openxla",
                                   fullgraph=True,
                                   dynamic=False)

    def __call__(
        self,
        token_ids: torch.Tensor,
        position_ids: torch.Tensor,
        attn_metadata: AttentionMetadata,
        input_lens: torch.Tensor,
        t: torch.Tensor,
        p: torch.Tensor,
        num_samples: int,
        kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
    ) -> torch.Tensor:
        # NOTE(woosuk): There are two stages of compilation: torch.compile and
        # XLA compilation. Using `mark_dynamic` can reduce the torch.compile
        # overhead by reusing the FX graph for different shapes.
        # However, the XLA graph will still require static shapes and needs to
        # be re-compiled for every different shape. This overhead is inevitable
        # in the first run, but can be skipped afterwards as we cache the XLA
        # graphs on disk (VLLM_XLA_CACHE_PATH).
        if attn_metadata.num_prefills > 0:
            # Prefill
            torch._dynamo.mark_dynamic(token_ids, 1)
            torch._dynamo.mark_dynamic(position_ids, 1)
            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
        else:
            # Decode
            torch._dynamo.mark_dynamic(token_ids, 0)
            torch._dynamo.mark_dynamic(position_ids, 0)
            torch._dynamo.mark_dynamic(input_lens, 0)
            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
            torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
            torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
            torch._dynamo.mark_dynamic(t, 0)
            torch._dynamo.mark_dynamic(p, 0)
        return self.model(token_ids, position_ids, attn_metadata, input_lens,
                          t, p, num_samples, kv_caches)
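
A minimal sketch (not vLLM code) of the pattern the class above introduces: compile once with dynamic=False, then call torch._dynamo.mark_dynamic on exactly the dimensions that vary between calls (the sequence-length dim for prefill, the batch dim for decode), so one FX graph is reused for those dims while everything else stays specialized.

# Minimal sketch (not vLLM code) of selective dynamism via mark_dynamic.
import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.silu(x).sum(dim=-1)

model = torch.compile(Toy(), fullgraph=True, dynamic=False)
for seq_len in (16, 32, 64):
    x = torch.randn(1, seq_len, 8)
    # Mark only the sequence-length dim as dynamic, mirroring the prefill branch above;
    # a decode-style call would instead mark dim 0 (batch).
    torch._dynamo.mark_dynamic(x, 1)
    print(seq_len, tuple(model(x).shape))
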
def _get_padded_prefill_len(x: int) -> int:
    # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
    # length to be a multiple of 16. We pad the prompt length to the nearest
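
The comment above is cut off at the hunk boundary, so the exact rounding rule is not visible here. Purely as an illustration of the stated multiple-of-16 constraint, a hypothetical helper (not necessarily what _get_padded_prefill_len actually does) could look like this:

# Hypothetical illustration only: round a prompt length up to the next multiple of 16,
# the minimum requirement stated for the Pallas FlashAttention kernel above.  The real
# _get_padded_prefill_len may pad more aggressively.
def pad_to_multiple_of_16(x: int) -> int:
    return ((x + 15) // 16) * 16

assert pad_to_multiple_of_16(1) == 16
assert pad_to_multiple_of_16(16) == 16
assert pad_to_multiple_of_16(17) == 32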