Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)

Compare commits: 6 commits, v0.10.2rc3...gpu-ids
Author | SHA1 | Date
---|---|---
 | 37cf1f27f2 |
 | 45ea3c31a2 |
 | df866cfebf |
 | 29596317b0 |
 | 7b86860ff5 |
 | f386a9e56c |
@@ -19,7 +19,20 @@ llm = LLM(model="ibm-granite/granite-3.1-8b-instruct",
 
 To ensure that vLLM initializes CUDA correctly, you should avoid calling related functions (e.g. [torch.cuda.set_device][])
 before initializing vLLM. Otherwise, you may run into an error like `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
 
-To control which devices are used, please instead set the `CUDA_VISIBLE_DEVICES` environment variable.
+To control which devices are used, you can either set the `CUDA_VISIBLE_DEVICES`
+environment variable, pass the `gpu_ids` parameter to the [LLM] constructor,
+or use the `--gpu-ids` option with `vllm serve`.
+
+```python
+from vllm import LLM
+
+# Use GPUs 0 and 2 for execution without setting CUDA_VISIBLE_DEVICES env var
+llm = LLM(
+    model="your-model",
+    gpu_ids=[0, 2],
+)
+```
+
 !!! note
     With tensor parallelism enabled, each process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the size of tensor parallelism).
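
The docstring added later in this change also allows a comma-separated string for `gpu_ids`; a minimal sketch of that variant, assuming the patch is applied (the model name is a placeholder):

```python
from vllm import LLM

# Sketch assuming this patch is applied: the string form "0,2" is written to
# CUDA_VISIBLE_DEVICES as-is, equivalent to gpu_ids=[0, 2] above.
llm = LLM(
    model="your-model",  # placeholder, not a real checkpoint
    gpu_ids="0,2",
)
```
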
@@ -1115,6 +1115,12 @@ def prepare_communication_buffer_for_model(model: torch.nn.Module):
     MoE all2all (DeepEP) usually allocate the communication buffer
     based on the model shape for optimal performance.
     """
+    orig = torch.cuda.current_device()
+    for d in range(8):
+        torch.cuda.set_device(d)
+        torch.zeros(1, device=f'cuda:{d}')
+    torch.cuda.set_device(orig)
+    print("pre-warmed all GPUs")
     if _TP is not None:
         _TP.prepare_communication_buffer_for_model(model)
     if _PP is not None:
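
One note on the loop above: `range(8)` assumes at least eight visible devices, and `torch.cuda.set_device(d)` raises on hosts with fewer. A defensive sketch (not part of this diff) that only touches devices that are actually visible:

```python
import torch

# Illustrative only, not from the diff: force lazy context creation on each
# visible CUDA device, then restore the original current device.
def prewarm_visible_gpus() -> None:
    if not torch.cuda.is_available():
        return
    orig = torch.cuda.current_device()
    for d in range(torch.cuda.device_count()):
        torch.zeros(1, device=f"cuda:{d}")
    torch.cuda.set_device(orig)
```
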
@@ -38,6 +38,9 @@ class ServeSubcommand(CLISubcommand):
 
     @staticmethod
     def cmd(args: argparse.Namespace) -> None:
+        # Allow overriding visible GPUs via --gpu-ids (comma-separated or single int)
+        if hasattr(args, 'gpu_ids') and args.gpu_ids is not None:
+            os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_ids
         # If model is specified in CLI (as positional arg), it takes precedence
         if hasattr(args, 'model_tag') and args.model_tag is not None:
             args.model = args.model_tag
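
The added lines only touch `os.environ`, so their effect can be checked without importing vLLM; a small sketch with a hand-built `Namespace` (purely illustrative):

```python
import argparse
import os

# Mimic the added cmd() logic in isolation: a --gpu-ids value, if present,
# overwrites CUDA_VISIBLE_DEVICES for the current process.
args = argparse.Namespace(gpu_ids="0,2", model_tag=None)
if getattr(args, "gpu_ids", None) is not None:
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids
print(os.environ["CUDA_VISIBLE_DEVICES"])  # prints: 0,2
```
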
@@ -98,8 +101,13 @@ class ServeSubcommand(CLISubcommand):
             help="Read CLI options from a config file. "
             "Must be a YAML with the following options: "
             "https://docs.vllm.ai/en/latest/configuration/serve_args.html")
 
         serve_parser = make_arg_parser(serve_parser)
+        serve_parser.add_argument(
+            "--gpu-ids",
+            type=str,
+            default=None,
+            help="Comma-separated GPU IDs or a single GPU ID to use for vLLM serve. "
+            "Overrides CUDA_VISIBLE_DEVICES.")
 
         show_filtered_argument_or_group_from_help(serve_parser, ["serve"])
         serve_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG
         return serve_parser
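
For context, the new flag is a plain string option; a self-contained argparse sketch (independent of vLLM's `make_arg_parser`) shows how a value like `0,2` reaches `args.gpu_ids`:

```python
import argparse

# Stand-alone sketch of the flag defined above; vLLM's real parser is built
# by make_arg_parser() and is not reproduced here.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--gpu-ids",
    type=str,
    default=None,
    help="Comma-separated GPU IDs or a single GPU ID. "
    "Overrides CUDA_VISIBLE_DEVICES.")

args = parser.parse_args(["--gpu-ids", "0,2"])
print(args.gpu_ids)  # prints: 0,2
```
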
@@ -9,6 +9,7 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
                     cast, overload)
 
 import cloudpickle
+import os
 import torch.nn as nn
 from pydantic import ValidationError
 from tqdm.auto import tqdm
@@ -75,6 +76,9 @@ class LLM:
         skip_tokenizer_init: If true, skip initialization of tokenizer and
             detokenizer. Expect valid prompt_token_ids and None for prompt
             from the input.
+        gpu_ids: A list of GPU device IDs or a comma-separated string to use
+            for vLLM execution. Overrides the CUDA_VISIBLE_DEVICES environment
+            variable.
         trust_remote_code: Trust remote code (e.g., from HuggingFace) when
             downloading the model and tokenizer.
         allowed_local_media_path: Allowing API requests to read local images
@@ -170,6 +174,7 @@ class LLM:
         tokenizer: Optional[str] = None,
         tokenizer_mode: TokenizerMode = "auto",
         skip_tokenizer_init: bool = False,
+        gpu_ids: Optional[Union[Sequence[int], str]] = None,
         trust_remote_code: bool = False,
         allowed_local_media_path: str = "",
         tensor_parallel_size: int = 1,
@@ -198,6 +203,13 @@ class LLM:
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
 
+        # Allow specifying GPU device IDs without using CUDA_VISIBLE_DEVICES env var
+        if gpu_ids is not None:
+            # gpu_ids can be a sequence of ints or a string
+            os.environ["CUDA_VISIBLE_DEVICES"] = (
+                ",".join(map(str, gpu_ids)) if isinstance(gpu_ids, (list, tuple))
+                else str(gpu_ids)
+            )
         if "worker_cls" in kwargs:
             worker_cls = kwargs["worker_cls"]
             # if the worker_cls is not qualified string name,
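
The normalization above is pure string handling, so it can be exercised on its own. A sketch of the same expression outside the constructor; note that the `isinstance` check covers `list` and `tuple` only, so other `Sequence` types fall through to `str()`:

```python
import os
from typing import Sequence, Union

# Sketch only: the same normalization the constructor change performs.
def visible_devices(gpu_ids: Union[Sequence[int], str]) -> str:
    return (",".join(map(str, gpu_ids))
            if isinstance(gpu_ids, (list, tuple)) else str(gpu_ids))

print(visible_devices([0, 2]))    # "0,2"
print(visible_devices("1,3"))     # "1,3"
print(visible_devices(range(2)))  # "range(0, 2)" -- falls through to str()
os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices([0, 2])
```
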