https://github.com/vllm-project/vllm.git
[Hardware][Intel GPU] Add intel GPU pipeline parallel support. (#7810)
@@ -666,6 +666,11 @@ class AsyncLLMEngine:
                 initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                 executor_class = RayXPUExecutorAsync
+            elif distributed_executor_backend == "mp":
+                initialize_ray_cluster(engine_config.parallel_config)
+                from vllm.executor.multiproc_xpu_executor import (
+                    MultiprocessingXPUExecutorAsync)
+                executor_class = MultiprocessingXPUExecutorAsync
             else:
                 raise RuntimeError(
                     "Not supported distributed execution model on XPU device.")
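A pipeline-parallel XPU deployment would reach the ray branch above through the async engine. The following is a minimal usage sketch and not part of this diff; `device`, `pipeline_parallel_size`, and `distributed_executor_backend` are existing engine arguments, while the model name, prompt, and two-stage split are placeholders.

import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

# Hypothetical 2-stage pipeline-parallel launch on XPU; the "ray" backend
# selects RayXPUExecutorAsync in the branch above.
engine_args = AsyncEngineArgs(
    model="facebook/opt-125m",              # placeholder model
    device="xpu",
    pipeline_parallel_size=2,
    distributed_executor_backend="ray",
)
engine = AsyncLLMEngine.from_engine_args(engine_args)


async def main() -> None:
    params = SamplingParams(max_tokens=16)
    final = None
    async for out in engine.generate("Hello", params, request_id="req-0"):
        final = out
    print(final.outputs[0].text)


asyncio.run(main())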
@@ -472,6 +472,13 @@ class LLMEngine:
                 initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.ray_xpu_executor import RayXPUExecutor
                 executor_class = RayXPUExecutor
+            elif distributed_executor_backend == "mp":
+                # FIXME(kunshang):
+                # spawn needs calling `if __name__ == '__main__':``
+                # fork is not supported for xpu start new process.
+                logger.error(
+                    "Both start methods (spawn and fork) have issue "
+                    "on XPU if you use mp backend, Please try ray instead.")
             else:
                 from vllm.executor.xpu_executor import XPUExecutor
                 executor_class = XPUExecutor
@@ -30,16 +30,12 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
     uses_ray: bool = False
 
     def _init_executor(self) -> None:
+        self._check_executor_parameters()
+
         # Create the parallel GPU workers.
         world_size = self.parallel_config.world_size
         tensor_parallel_size = self.parallel_config.tensor_parallel_size
 
-        # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
-        if "CUDA_VISIBLE_DEVICES" not in os.environ:
-            update_environment_variables({
-                "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
-            })
-
         # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
         os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()
@@ -68,16 +64,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
         if world_size > 1:
             maybe_set_triton_cache_manager()
 
-        cuda_device_count = cuda_device_count_stateless()
-        # Use confusing message for more common TP-only case.
-        assert tensor_parallel_size <= cuda_device_count, (
-            f"please set tensor_parallel_size ({tensor_parallel_size}) "
-            f"to less than max local gpu count ({cuda_device_count})")
-
-        assert world_size <= cuda_device_count, (
-            f"please ensure that world_size ({world_size}) "
-            f"is less than than max local gpu count ({cuda_device_count})")
-
         # Multiprocessing-based executor does not support multi-node setting.
         # Since it only works for single node, we can use the loopback address
         # 127.0.0.1 for communication.
@@ -139,6 +125,26 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
                           max_concurrent_workers=self.parallel_config.
                           max_parallel_loading_workers)
 
+    def _check_executor_parameters(self):
+        world_size = self.parallel_config.tensor_parallel_size
+        tensor_parallel_size = self.parallel_config.tensor_parallel_size
+
+        # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
+        if "CUDA_VISIBLE_DEVICES" not in os.environ:
+            update_environment_variables({
+                "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
+            })
+
+        cuda_device_count = cuda_device_count_stateless()
+        # Use confusing message for more common TP-only case.
+        assert tensor_parallel_size <= cuda_device_count, (
+            f"please set tensor_parallel_size ({tensor_parallel_size}) "
+            f"to less than max local gpu count ({cuda_device_count})")
+
+        assert world_size <= cuda_device_count, (
+            f"please ensure that world_size ({world_size}) "
+            f"is less than than max local gpu count ({cuda_device_count})")
+
     def shutdown(self):
         if (worker_monitor := getattr(self, "worker_monitor",
                                       None)) is not None:
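As a worked illustration of what the extracted check guards (not vLLM code): with a hypothetical world size of 4 and 8 visible devices, the default visibility string and the bounds it asserts look like this.

import os

# Hypothetical values; in the executor they come from parallel_config
# and cuda_device_count_stateless().
world_size = 4
tensor_parallel_size = 4
cuda_device_count = 8

# Default visibility string the driver would export for its workers.
visible = ",".join(map(str, range(world_size)))
assert visible == "0,1,2,3"

# The same bounds the executor asserts before spawning workers.
assert tensor_parallel_size <= cuda_device_count
assert world_size <= cuda_device_count

os.environ.setdefault("CUDA_VISIBLE_DEVICES", visible)
print(os.environ["CUDA_VISIBLE_DEVICES"])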
vllm/executor/multiproc_xpu_executor.py (new file, 26 lines)
@@ -0,0 +1,26 @@
+import vllm.envs as envs
+from vllm.executor.multiproc_gpu_executor import (
+    MultiprocessingGPUExecutor, MultiprocessingGPUExecutorAsync)
+from vllm.executor.xpu_executor import XPUExecutor
+from vllm.logger import init_logger
+from vllm.utils import make_async
+
+logger = init_logger(__name__)
+
+
+class MultiprocessingXPUExecutor(MultiprocessingGPUExecutor, XPUExecutor):
+    """Python multiprocessing-based multi-XPU executor"""
+
+    def _check_executor_parameters(self):
+        mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
+        if mp_method != "spawn":
+            raise RuntimeError(
+                "XPU multiprocess executor only support spawn as mp method")
+
+
+class MultiprocessingXPUExecutorAsync(MultiprocessingXPUExecutor,
+                                      MultiprocessingGPUExecutorAsync):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.driver_exec_model = make_async(self.driver_worker.execute_model)
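Because the new executor rejects any start method other than spawn, an mp launch also needs the spawn-safe entry-point guard that the FIXME in the engine hunk alludes to. A rough launch sketch under that assumption (the model name and card count are placeholders; `VLLM_WORKER_MULTIPROC_METHOD` is the existing vLLM environment variable):

import os

# The XPU multiprocessing executor insists on spawn, so select it before
# vLLM reads its environment, and keep engine construction under the guard
# below: spawn re-imports this module in every worker process.
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

if __name__ == "__main__":
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(
            model="facebook/opt-125m",          # placeholder model
            device="xpu",
            tensor_parallel_size=2,
            distributed_executor_backend="mp",
        ))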
@@ -12,6 +12,7 @@ from vllm.attention import get_attn_backend
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ObservabilityConfig, ParallelConfig,
                          PromptAdapterConfig, SchedulerConfig)
+from vllm.distributed import get_pp_group
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
@@ -439,9 +440,11 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
                 "Setting it to the minimum value of 1.", expr)
             max_num_seqs = 1
 
+        batch_size = 0
         for group_id in range(max_num_seqs):
             seq_len = (max_num_batched_tokens // max_num_seqs +
                        (group_id < max_num_batched_tokens % max_num_seqs))
+            batch_size += seq_len
 
             seq_data, dummy_multi_modal_data = self.input_registry \
                 .dummy_data_for_profiling(self.model_config,
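The accumulated `batch_size` is simply the total number of dummy tokens spread across the profiled sequences. A small worked example with hypothetical profiling limits, mirroring the arithmetic above:

max_num_batched_tokens = 8192   # hypothetical profiling limits
max_num_seqs = 5

seq_lens = [
    max_num_batched_tokens // max_num_seqs +
    (group_id < max_num_batched_tokens % max_num_seqs)
    for group_id in range(max_num_seqs)
]
batch_size = sum(seq_lens)

print(seq_lens)    # [1639, 1639, 1638, 1638, 1638]
print(batch_size)  # 8192 -- later fed to make_empty_intermediate_tensors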
@@ -465,7 +468,13 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
         finished_requests_ids = [seq.request_id for seq in seqs]
         model_input = self.prepare_model_input(
             seqs, finished_requests_ids=finished_requests_ids)
-        self.execute_model(model_input, kv_caches)
+        intermediate_tensors = None
+        if not get_pp_group().is_first_rank:
+            intermediate_tensors = self.model.make_empty_intermediate_tensors(
+                batch_size=batch_size,
+                dtype=self.model_config.dtype,
+                device=self.device)
+        self.execute_model(model_input, kv_caches, intermediate_tensors)
         torch.xpu.synchronize()
         return
@@ -537,7 +546,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
                 and self.observability_config.collect_model_forward_time):
             model_forward_start_time = time.time()
 
-        hidden_states = model_executable(
+        hidden_or_intermediate_states = model_executable(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            kv_caches=kv_caches,
@@ -545,12 +554,16 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
             intermediate_tensors=intermediate_tensors,
             **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
                                          device=self.device))
+        # Compute the logits in the last pipeline stage.
+        if not get_pp_group().is_last_rank:
+            return hidden_or_intermediate_states
+
         if (self.observability_config is not None
                 and self.observability_config.collect_model_forward_time):
             model_forward_end_time = time.time()
 
         # Compute the logits.
-        logits = self.model.compute_logits(hidden_states,
+        logits = self.model.compute_logits(hidden_or_intermediate_states,
                                            model_input.sampling_metadata)
 
         # Only perform sampling in the driver worker.
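The control flow the two hunks above introduce can be summarised outside of vLLM as a small stand-alone sketch: every stage forwards its own layers, non-last stages hand intermediate activations onward, and only the last stage turns hidden states into logits. The names here (`run_pipeline_stage`, `is_last_rank`, the toy linear layers) are illustrative only.

import torch


def run_pipeline_stage(stage: torch.nn.Module, lm_head: torch.nn.Module,
                       hidden_or_intermediate: torch.Tensor,
                       is_last_rank: bool) -> torch.Tensor:
    """Forward one stage; only the last stage produces logits."""
    out = stage(hidden_or_intermediate)
    if not is_last_rank:
        # Mirrors the early return above: ship activations to the next stage.
        return out
    # Last stage only: project hidden states to vocabulary logits.
    return lm_head(out)


# Toy two-stage pipeline over a tiny MLP, run in one process for clarity.
hidden, vocab = 16, 32
stage0 = torch.nn.Linear(hidden, hidden)
stage1 = torch.nn.Linear(hidden, hidden)
lm_head = torch.nn.Linear(hidden, vocab)

x = torch.randn(4, hidden)
intermediate = run_pipeline_stage(stage0, lm_head, x, is_last_rank=False)
logits = run_pipeline_stage(stage1, lm_head, intermediate, is_last_rank=True)
print(logits.shape)  # torch.Size([4, 32])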
@@ -14,6 +14,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          SpeculativeConfig)
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
+from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.utils import is_xpu
@@ -198,3 +199,8 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
         ensure_model_parallel_initialized(
             parallel_config.tensor_parallel_size,
             parallel_config.pipeline_parallel_size)
+
+        if parallel_config.pipeline_parallel_size > 1:
+            # torch-ccl xpu need a collective API warm up
+            # before calling send/recv API
+            get_pp_group().all_reduce(torch.zeros(1).xpu())
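The warm-up mirrors a common idiom: issue one cheap collective so the communicator is fully established before the first point-to-point send/recv. A generic torch.distributed sketch of the same idea, not taken from this commit (gloo on CPU here purely so it runs anywhere; the commit itself performs the all_reduce on the XPU pipeline-parallel group):

import torch
import torch.distributed as dist


def warm_up_then_p2p(rank: int) -> None:
    # A single tiny all_reduce forces communicator setup before send/recv.
    dist.all_reduce(torch.zeros(1))
    payload = torch.full((1,), float(rank))
    if rank == 0:
        dist.send(payload, dst=1)
    elif rank == 1:
        dist.recv(payload, src=0)
        print("rank 1 received", payload.item())


if __name__ == "__main__":
    # Launched e.g. with `torchrun --nproc_per_node=2 this_script.py`.
    dist.init_process_group(backend="gloo")
    warm_up_then_p2p(dist.get_rank())
    dist.destroy_process_group()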