Mirror of https://github.com/vllm-project/vllm.git
[TPU] Support tensor parallelism in async llm engine (#6891)
@@ -12,6 +12,9 @@ RUN pip install "numpy<2"
 RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
 RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
 
+# Fix FastAPI dependence
+RUN pip install "starlette<0.38.0"
+
 # Build vLLM.
 COPY . /workspace/vllm
 ENV VLLM_TARGET_DEVICE="tpu"
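The starlette pin works around a FastAPI/starlette version conflict. A minimal sanity check, not part of the commit, that can be run inside the built image to confirm the pin took effect (the packaging dependency is an assumption about the environment, though it usually ships with pip):

# Sanity check (illustrative, not from the commit): confirm the installed
# starlette honors the "<0.38.0" pin and FastAPI still imports against it.
from importlib.metadata import version
from packaging.version import Version

assert Version(version("starlette")) < Version("0.38.0")
import fastapi  # should import cleanly with the pinned starlette
print("starlette", version("starlette"), "fastapi", fastapi.__version__)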
@@ -407,8 +407,14 @@ class AsyncLLMEngine:
             from vllm.executor.neuron_executor import NeuronExecutorAsync
             executor_class = NeuronExecutorAsync
         elif engine_config.device_config.device_type == "tpu":
-            from vllm.executor.tpu_executor import TPUExecutorAsync
-            executor_class = TPUExecutorAsync
+            if distributed_executor_backend == "ray":
+                initialize_ray_cluster(engine_config.parallel_config)
+                from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
+                executor_class = RayTPUExecutorAsync
+            else:
+                assert distributed_executor_backend is None
+                from vllm.executor.tpu_executor import TPUExecutorAsync
+                executor_class = TPUExecutorAsync
         elif engine_config.device_config.device_type == "cpu":
             from vllm.executor.cpu_executor import CPUExecutorAsync
             executor_class = CPUExecutorAsync
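With this branch in place, an async engine on a TPU device with the "ray" distributed backend is served by RayTPUExecutorAsync, while a single-device setup still falls through to TPUExecutorAsync. A hypothetical usage sketch (the model name and parallel size are illustrative assumptions, not taken from the commit):

# Hypothetical usage sketch: tensor_parallel_size > 1 together with the "ray"
# backend exercises the new RayTPUExecutorAsync path added above.
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

engine_args = AsyncEngineArgs(
    model="meta-llama/Llama-2-7b-hf",    # assumed example model
    tensor_parallel_size=4,              # spans multiple TPU chips
    distributed_executor_backend="ray",  # selects the Ray-based TPU executor
)
engine = AsyncLLMEngine.from_engine_args(engine_args)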