[Bugfix] neuron: enable tensor parallelism (#7562)

Signed-off-by: omrishiv <327609+omrishiv@users.noreply.github.com>
Author: omrishiv
Date: 2024-08-26 15:13:13 -07:00
Committed by: GitHub
Parent: 05826c887b
Commit: 760e9f71a8

3 changed files with 44 additions and 11 deletions
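
After this fix, tensor parallelism on Neuron can be requested the same way as on GPU. A hedged usage sketch (model name and sizes are placeholders; a Neuron host with transformers-neuronx installed is assumed):

# Usage sketch only, not part of this commit.
from vllm import LLM

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # placeholder model
          device="neuron",
          tensor_parallel_size=2,  # spread the model across two NeuronCores
          max_model_len=2048)      # on neuron, block_size is now derived from this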

vllm/engine/arg_utils.py

@@ -317,9 +317,10 @@ class EngineArgs:
         parser.add_argument('--block-size',
                             type=int,
                             default=EngineArgs.block_size,
-                            choices=[8, 16, 32, 128, 256, 512, 1024, 2048],
+                            choices=[8, 16, 32],
                             help='Token block size for contiguous chunks of '
-                            'tokens.')
+                            'tokens. This is ignored on neuron devices and '
+                            'set to max-model-len')
 
         parser.add_argument('--enable-prefix-caching',
                             action='store_true',
@@ -793,7 +794,8 @@ class EngineArgs:
             limit_mm_per_prompt=self.limit_mm_per_prompt,
         )
         cache_config = CacheConfig(
-            block_size=self.block_size,
+            block_size=self.block_size if self.device != "neuron" else
+            self.max_model_len,  # neuron needs block_size = max_model_len
             gpu_memory_utilization=self.gpu_memory_utilization,
             swap_space=self.swap_space,
             cache_dtype=self.kv_cache_dtype,
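
Taken together, these two hunks mean --block-size is again validated against the usual GPU choices, while on neuron the KV-cache block size is pinned to max-model-len regardless of what was passed. A standalone sketch of that selection logic (simplified; effective_block_size is a made-up helper, not vLLM code):

# Hypothetical helper mirroring the block-size selection above.
def effective_block_size(device: str, block_size: int, max_model_len: int) -> int:
    # Neuron keeps an entire sequence in a single KV-cache block,
    # so the block size must equal max_model_len on that device.
    return block_size if device != "neuron" else max_model_len

assert effective_block_size("cuda", 16, 4096) == 16
assert effective_block_size("neuron", 16, 4096) == 4096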

vllm/executor/neuron_executor.py

@@ -4,7 +4,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.sequence import ExecuteModelRequest, SamplerOutput
-from vllm.utils import make_async
+from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
+                        make_async)
 
 logger = init_logger(__name__)
@@ -24,14 +25,17 @@ class NeuronExecutor(ExecutorBase):
     def _init_worker(self):
         from vllm.worker.neuron_worker import NeuronWorker
+        distributed_init_method = get_distributed_init_method(
+            get_ip(), get_open_port())
         self.driver_worker = NeuronWorker(
-            self.model_config,
-            self.parallel_config,
-            self.scheduler_config,
-            self.device_config,
-            self.cache_config,
-        )
+            model_config=self.model_config,
+            parallel_config=self.parallel_config,
+            scheduler_config=self.scheduler_config,
+            device_config=self.device_config,
+            cache_config=self.cache_config,
+            local_rank=0,
+            rank=0,
+            distributed_init_method=distributed_init_method)
         self.driver_worker.init_device()
         self.driver_worker.load_model()
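
The executor now builds a rendezvous address from the host's IP and a free port and hands it, together with local_rank=0 and rank=0, to the single driver worker. A rough illustration of what that address looks like, using the same helpers the new import brings in (the printed value is only an example; the exact URL format is whatever vllm.utils produces):

# Illustration only: build the kind of init method the executor passes on.
from vllm.utils import get_distributed_init_method, get_ip, get_open_port

init_method = get_distributed_init_method(get_ip(), get_open_port())
print(init_method)  # e.g. something like "tcp://10.0.1.12:51217"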

vllm/worker/neuron_worker.py

@@ -6,6 +6,8 @@ import torch.distributed
 from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
+from vllm.distributed import (ensure_model_parallel_initialized,
+                              init_distributed_environment)
 from vllm.model_executor import set_random_seed
 from vllm.sequence import ExecuteModelRequest
 from vllm.worker.neuron_model_runner import NeuronModelRunner
@@ -24,12 +26,18 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         scheduler_config: SchedulerConfig,
         device_config: DeviceConfig,
         cache_config: CacheConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.cache_config = cache_config
+        self.local_rank = local_rank
+        self.rank = rank
+        self.distributed_init_method = distributed_init_method
         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
             from vllm.utils import init_cached_hf_modules
@@ -40,6 +48,8 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         self.is_driver_worker = True
 
     def init_device(self) -> None:
+        self.init_distributed_environment()
+
         # Set random seed.
         set_random_seed(self.model_config.seed)
@@ -98,3 +108,20 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         This is required for speculative decoding; it is not yet implemented.
         """
         raise NotImplementedError
+
+    def init_distributed_environment(self):
+        """Neuron uses transformers-neuronx for tensor parallelism.
+
+        vLLM still needs the environment inited when TP/PP > 1
+        """
+        init_distributed_environment(
+            world_size=1,
+            rank=self.rank,
+            local_rank=self.local_rank,
+            distributed_init_method=self.distributed_init_method,
+            backend="gloo",
+        )
+        ensure_model_parallel_initialized(
+            1,
+            1,
+        )
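
The new init_distributed_environment sets up a single-rank process group over the gloo backend purely so that vLLM's parallel-state bookkeeping (ensure_model_parallel_initialized) is satisfied; the actual tensor parallelism across NeuronCores stays inside transformers-neuronx. A minimal torch-level sketch of the same idea, independent of vLLM (the TCP address is an arbitrary local example):

# Standalone illustration, not vLLM code: a world_size=1 gloo group like
# the one the worker now creates before loading the model.
import torch.distributed as dist

dist.init_process_group(backend="gloo",
                        init_method="tcp://127.0.0.1:29501",
                        world_size=1,
                        rank=0)
assert dist.get_world_size() == 1
dist.destroy_process_group()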