Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)
Compare commits: 1 commit (ab153be252)
@@ -338,6 +338,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "CUDA_VISIBLE_DEVICES":
     lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None),
 
+    # used to control the visible devices in the distributed setting
+    "VLLM_VISIBLE_DEVICES":
+    lambda: os.environ.get("VLLM_VISIBLE_DEVICES", None),
+
     # timeout for each iteration in the engine
     "VLLM_ENGINE_ITERATION_TIMEOUT_S":
     lambda: int(os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60")),
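Note: the envs module resolves these entries lazily, so reading an attribute like envs.VLLM_VISIBLE_DEVICES re-evaluates the lambda against os.environ at access time rather than freezing the value at import. A minimal standalone sketch of that pattern (this is an illustration, not the vLLM module itself):

import os
from typing import Any, Callable

# Lazy env-var table, mirroring the entries added in the hunk above.
environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_VISIBLE_DEVICES":
    lambda: os.environ.get("VLLM_VISIBLE_DEVICES", None),
    "VLLM_ENGINE_ITERATION_TIMEOUT_S":
    lambda: int(os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60")),
}

def __getattr__(name: str) -> Any:
    # Module-level __getattr__ (PEP 562): each attribute access re-runs the
    # lambda, so the returned value always tracks the current environment.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module has no attribute {name!r}")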
@@ -135,7 +135,14 @@ class Worker(WorkerBase):
 
             # This env var set by Ray causes exceptions with graph building.
             os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
-            self.device = torch.device(f"cuda:{self.local_rank}")
+
+            device_id = self.local_rank
+            if envs.VLLM_VISIBLE_DEVICES is not None:
+                devices = [
+                    int(dev) for dev in (x.strip() for x in envs.VLLM_VISIBLE_DEVICES.split(','))
+                ]
+                device_id = devices[self.local_rank]
+            self.device = torch.device(f"cuda:{device_id}")
             current_platform.set_device(self.device)
 
             _check_if_gpu_supports_dtype(self.model_config.dtype)
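The worker change maps a local rank to a physical CUDA device through VLLM_VISIBLE_DEVICES: unset, rank N keeps cuda:N; set to a comma-separated list, rank N uses the N-th listed device. A minimal sketch of that mapping (resolve_device_id is a hypothetical helper, not part of the Worker class):

import os

def resolve_device_id(local_rank: int) -> int:
    # Illustrates the logic in the hunk above: fall back to the local rank,
    # otherwise index into the parsed VLLM_VISIBLE_DEVICES list.
    visible = os.environ.get("VLLM_VISIBLE_DEVICES")
    if visible is None:
        return local_rank
    devices = [int(dev) for dev in (x.strip() for x in visible.split(","))]
    return devices[local_rank]

os.environ["VLLM_VISIBLE_DEVICES"] = "2, 3"
print(resolve_device_id(0), resolve_device_id(1))  # -> 2 3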