mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Hardware][TPU] Support MoE with Pallas GMM kernel (#6457)
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
ARG NIGHTLY_DATE="20240601"
|
||||
ARG NIGHTLY_DATE="20240713"
|
||||
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
|
||||
|
||||
FROM $BASE_IMAGE
|
||||
@ -6,6 +6,8 @@ WORKDIR /workspace
|
||||
|
||||
# Install aiohttp separately to avoid build errors.
|
||||
RUN pip install aiohttp
|
||||
# Install NumPy 1 instead of NumPy 2.
|
||||
RUN pip install "numpy<2"
|
||||
# Install the TPU and Pallas dependencies.
|
||||
RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
|
||||
RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||
|
@ -56,7 +56,7 @@ First, install the dependencies:
|
||||
$ pip uninstall torch torch-xla -y
|
||||
|
||||
$ # Install PyTorch and PyTorch XLA.
|
||||
$ export DATE="+20240601"
|
||||
$ export DATE="+20240713"
|
||||
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
|
||||
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl
|
||||
|
||||
@ -85,7 +85,7 @@ Next, build vLLM from source. This will only take a few seconds:
|
||||
ImportError: libopenblas.so.0: cannot open shared object file: No such file or directory
|
||||
|
||||
|
||||
You can install OpenBLAS with the following command:
|
||||
Please install OpenBLAS with the following command:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
|
@ -104,6 +104,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
raise NotImplementedError(
|
||||
"The CPU backend currently does not support MoE.")
|
||||
|
||||
def forward_tpu(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
num_expert_group: Optional[int],
|
||||
topk_group: Optional[int],
|
||||
) -> torch.Tensor:
|
||||
from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
|
||||
assert not use_grouped_topk
|
||||
assert num_expert_group is None
|
||||
assert topk_group is None
|
||||
return fused_moe(x, w1, w2, router_logits, top_k, renormalize)
|
||||
|
||||
|
||||
class FusedMoE(torch.nn.Module):
|
||||
"""FusedMoE layer for MoE models.
|
||||
|
62
vllm/model_executor/layers/fused_moe/moe_pallas.py
Normal file
62
vllm/model_executor/layers/fused_moe/moe_pallas.py
Normal file
@ -0,0 +1,62 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch_xla.experimental.custom_kernel import _histogram
|
||||
|
||||
|
||||
def fused_moe(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
hidden_states: [*, hidden_size]
|
||||
w1: [num_experts, intermediate_size * 2, hidden_size]
|
||||
w2: [num_experts, hidden_size, intermediate_size]
|
||||
gating_output: [*, num_experts]
|
||||
"""
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_size = hidden_states.shape[-1]
|
||||
num_tokens = hidden_states.shape[:-1].numel()
|
||||
num_experts = w1.shape[0]
|
||||
intermediate_size = w2.shape[-1]
|
||||
device = hidden_states.device
|
||||
dtype = hidden_states.dtype
|
||||
assert (num_tokens * topk) % 16 == 0, (
|
||||
"The Pallas GMM kernel requires num_tokens * topk to be a multiple of "
|
||||
f"16 but got {num_tokens * topk}")
|
||||
|
||||
hidden_states = hidden_states.view(num_tokens, hidden_size)
|
||||
gating_output = gating_output.view(num_tokens, num_experts)
|
||||
topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
|
||||
topk_weights, topk_indices = topk_weights.topk(topk, dim=-1)
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
topk_weights = topk_weights.to(dtype)
|
||||
|
||||
topk_indices = topk_indices.flatten()
|
||||
topk_argsort_indices = topk_indices.argsort()
|
||||
topk_argsort_revert_indices = topk_argsort_indices.argsort()
|
||||
token_indices = torch.arange(num_tokens,
|
||||
device=device).repeat_interleave(topk)
|
||||
token_indices = token_indices[topk_argsort_indices]
|
||||
group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1)
|
||||
|
||||
# NOTE(woosuk): The GMM Pallas kernel requires a different weight layout
|
||||
# from HF Transformers.
|
||||
w1 = w1.transpose(1, 2)
|
||||
w2 = w2.transpose(1, 2)
|
||||
|
||||
x = hidden_states[token_indices]
|
||||
x = torch.ops.xla.gmm(x, w1, group_sizes)
|
||||
x = F.silu(x[..., :intermediate_size]) * x[..., intermediate_size:]
|
||||
x = torch.ops.xla.gmm(x, w2, group_sizes)
|
||||
x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
|
||||
|
||||
x = x * topk_weights.unsqueeze_(dim=-1)
|
||||
x = x.sum(dim=-2)
|
||||
x = x.reshape(orig_shape)
|
||||
return x
|
@ -598,11 +598,10 @@ def _get_padded_prefill_len(x: int) -> int:
|
||||
|
||||
|
||||
def _get_padded_batch_size(batch_size: int) -> int:
|
||||
if batch_size <= 2:
|
||||
return batch_size
|
||||
elif batch_size <= 4:
|
||||
return 4
|
||||
elif batch_size <= 8:
|
||||
# The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
|
||||
# To meet this requirement in the simplest way, we set the minimal batch
|
||||
# size to 8.
|
||||
if batch_size <= 8:
|
||||
return 8
|
||||
else:
|
||||
return ((batch_size + 15) // 16) * 16
|
||||
|
Reference in New Issue
Block a user