Fix RAM OOM when loading large models in tensor parallel mode (#1395)

Co-authored-by: ran_lin <rlin@thoughtworks.com>
Author: boydfd
Date: 2023-11-21 11:02:42 +08:00
Committed by: GitHub
Parent: 819b18e7ba
Commit: 4bb6b67188
4 changed files with 52 additions and 7 deletions

View File

@@ -285,10 +285,12 @@ class ParallelConfig:
pipeline_parallel_size: int,
tensor_parallel_size: int,
worker_use_ray: bool,
+ max_parallel_loading_workers: Optional[int] = None,
) -> None:
self.pipeline_parallel_size = pipeline_parallel_size
self.tensor_parallel_size = tensor_parallel_size
self.worker_use_ray = worker_use_ray
+ self.max_parallel_loading_workers = max_parallel_loading_workers
self.world_size = pipeline_parallel_size * tensor_parallel_size
if self.world_size > 1:
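
The new argument defaults to None, which keeps the old behaviour of loading on every worker at once. Below is a rough, self-contained sketch of the intended semantics; ParallelConfigSketch is an invented stand-in for illustration only, not the real vllm.config.ParallelConfig:

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class ParallelConfigSketch:
        """Illustrative stand-in for vLLM's ParallelConfig (not the real class)."""
        pipeline_parallel_size: int
        tensor_parallel_size: int
        worker_use_ray: bool
        # None keeps the previous behaviour (every worker loads weights at once);
        # a positive value caps how many workers load concurrently.
        max_parallel_loading_workers: Optional[int] = None

        @property
        def world_size(self) -> int:
            return self.pipeline_parallel_size * self.tensor_parallel_size


    cfg = ParallelConfigSketch(pipeline_parallel_size=1,
                               tensor_parallel_size=4,
                               worker_use_ray=True,
                               max_parallel_loading_workers=2)
    assert cfg.world_size == 4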

View File

@@ -22,6 +22,7 @@ class EngineArgs:
worker_use_ray: bool = False
pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1
+ max_parallel_loading_workers: Optional[int] = None
block_size: int = 16
swap_space: int = 4 # GiB
gpu_memory_utilization: float = 0.90
@@ -128,6 +129,12 @@ class EngineArgs:
type=int,
default=EngineArgs.tensor_parallel_size,
help='number of tensor parallel replicas')
+ parser.add_argument(
+ '--max-parallel-loading-workers',
+ type=int,
+ help='load model sequentially in multiple batches, '
+ 'to avoid RAM OOM when using tensor '
+ 'parallel and large models')
# KV cache arguments
parser.add_argument('--block-size',
type=int,
@@ -195,7 +202,8 @@
getattr(model_config.hf_config, 'sliding_window', None))
parallel_config = ParallelConfig(self.pipeline_parallel_size,
self.tensor_parallel_size,
- self.worker_use_ray)
+ self.worker_use_ray,
+ self.max_parallel_loading_workers)
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs,
model_config.max_model_len,
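
On the Python side the same knob can be set through EngineArgs; the sketch below is a hedged usage example that assumes a vLLM build from around this commit. The model name is arbitrary, and the exact return shape of create_engine_configs() may differ across versions. On the command line the equivalent flag is --max-parallel-loading-workers, as added in the argparse hunk above.

    # Usage sketch only; assumes a vLLM version from around this commit.
    from vllm.engine.arg_utils import EngineArgs

    args = EngineArgs(
        model="facebook/opt-13b",   # any large model; the name is illustrative
        tensor_parallel_size=4,
        # Load weights on at most two workers at a time instead of all four,
        # trading some startup latency for a lower peak in host RAM.
        max_parallel_loading_workers=2,
    )

    # create_engine_configs() is assumed here to return
    # (model_config, cache_config, parallel_config, scheduler_config).
    _, _, parallel_config, _ = args.create_engine_configs()
    assert parallel_config.max_parallel_loading_workers == 2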

View File

@@ -143,6 +143,12 @@ class LLMEngine:
"init_model",
get_all_outputs=True,
)
+ self._run_workers(
+ "load_model",
+ get_all_outputs=True,
+ max_concurrent_workers=self.parallel_config.
+ max_parallel_loading_workers,
+ )
def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
@@ -182,6 +188,12 @@
"init_model",
get_all_outputs=True,
)
+ self._run_workers(
+ "load_model",
+ get_all_outputs=True,
+ max_concurrent_workers=self.parallel_config.
+ max_parallel_loading_workers,
+ )
def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
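
To see why capping concurrency helps, some back-of-the-envelope arithmetic; the mechanism and the numbers below are illustrative assumptions, not measurements from this change. If every tensor-parallel worker on a node buffers the full checkpoint in host RAM while slicing out its own shard, peak usage grows with the number of workers, and batching the load caps that growth:

    # Rough illustration with assumed numbers; not measurements from this commit.
    model_size_gib = 130              # e.g. a ~65B-parameter model in fp16
    tensor_parallel_size = 8          # workers sharing one node's host RAM
    max_parallel_loading_workers = 2  # new cap introduced by this change

    peak_unbatched = model_size_gib * tensor_parallel_size        # ~1040 GiB
    peak_batched = model_size_gib * max_parallel_loading_workers  # ~260 GiB

    print(f"unbatched peak ~{peak_unbatched} GiB, "
          f"batched peak ~{peak_batched} GiB of host RAM")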
@@ -682,16 +694,15 @@ class LLMEngine:
seq.status = SequenceStatus.FINISHED_STOPPED
return
- def _run_workers(
+ def _run_workers_in_batch(
self,
+ workers,
method: str,
*args,
- get_all_outputs: bool = False,
**kwargs,
- ) -> Any:
- """Runs the given method on all workers."""
+ ):
all_outputs = []
- for worker in self.workers:
+ for worker in workers:
if self.parallel_config.worker_use_ray:
executor = partial(worker.execute_method.remote, method)
else:
@@ -699,9 +710,31 @@
output = executor(*args, **kwargs)
all_outputs.append(output)
if self.parallel_config.worker_use_ray:
all_outputs = ray.get(all_outputs)
+ return all_outputs
+ def _run_workers(
+ self,
+ method: str,
+ *args,
+ get_all_outputs: bool = False,
+ max_concurrent_workers: Optional[int] = None,
+ **kwargs,
+ ) -> Any:
+ """Runs the given method on all workers."""
+ all_outputs = []
+ if max_concurrent_workers:
+ work_groups = [
+ self.workers[i:i + max_concurrent_workers]
+ for i in range(0, len(self.workers), max_concurrent_workers)
+ ]
+ else:
+ work_groups = [self.workers]
+ for workers in work_groups:
+ all_outputs.extend(
+ self._run_workers_in_batch(workers, method, *args, **kwargs))
if get_all_outputs:
return all_outputs
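
The core of the change is the grouping in _run_workers: the worker list is split into batches of at most max_concurrent_workers, and each batch is dispatched through _run_workers_in_batch. The following is a minimal self-contained sketch of the same pattern, with plain callables standing in for Ray actors; all names here are invented for illustration:

    from typing import Any, Callable, List, Optional


    def run_in_groups(workers: List[Callable[[], Any]],
                      max_concurrent_workers: Optional[int] = None) -> List[Any]:
        """Run `workers` in batches of at most `max_concurrent_workers`.

        None means a single batch, matching the old behaviour of _run_workers.
        """
        if max_concurrent_workers:
            groups = [
                workers[i:i + max_concurrent_workers]
                for i in range(0, len(workers), max_concurrent_workers)
            ]
        else:
            groups = [workers]

        outputs: List[Any] = []
        for group in groups:
            # In the real engine this is where execute_method / ray.get happens;
            # here each "worker" is just a plain callable.
            outputs.extend(w() for w in group)
        return outputs


    # Four fake "workers"; with a cap of 2, they run in two batches.
    fake_workers = [lambda i=i: f"loaded shard {i}" for i in range(4)]
    print(run_in_groups(fake_workers, max_concurrent_workers=2))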

View File

@@ -67,6 +67,8 @@ class Worker:
# Initialize the model.
set_random_seed(self.model_config.seed)
- self.model = get_model(self.model_config)
+ def load_model(self):
+ self.model = get_model(self.model_config)
@torch.inference_mode()
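
Finally, on the worker side, weight loading moves out of init_model into a separate load_model step so the engine can stagger it across workers. Below is a simplified stand-in for that split, not the real vllm.worker.worker.Worker; get_model is replaced by a placeholder string:

    class WorkerSketch:
        """Illustrative stand-in for the worker-side split (not the real Worker)."""

        def __init__(self, rank: int):
            self.rank = rank
            self.model = None

        def init_model(self) -> None:
            # In vLLM this step sets up the distributed environment and the seed;
            # after this change it no longer touches the weights.
            print(f"worker {self.rank}: distributed init done")

        def load_model(self) -> None:
            # The engine now dispatches this via
            # _run_workers("load_model", max_concurrent_workers=N),
            # so only N workers execute it at the same time.
            self.model = f"placeholder weights for rank {self.rank}"


    w = WorkerSketch(rank=0)
    w.init_model()
    w.load_model()
    assert w.model is not None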