[ROCm][Feature] Enable Pipeline Parallelism with Ray Compiled Graph on ROCm (#24275)

Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
Charlie Fu
2025-09-09 18:27:35 -05:00
committed by GitHub
parent fb1a8f932a
commit 73e688cb79
3 changed files with 16 additions and 3 deletions

View File

@ -104,6 +104,7 @@ COPY --from=export_vllm /examples ${COMMON_WORKDIR}/vllm/examples
COPY --from=export_vllm /docker ${COMMON_WORKDIR}/vllm/docker
ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
ENV RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=1
ENV TOKENIZERS_PARALLELISM=false
# ENV that can improve safe tensor loading, and end-to-end time

View File

@ -8,7 +8,7 @@ numba == 0.61.2; python_version > '3.9'
boto3
botocore
datasets
ray>=2.10.0,<2.45.0
ray[cgraph]>=2.48.0 # Ray Compiled Graph, required for pipeline parallelism in V1.
peft
pytest-asyncio
tensorizer==2.10.1

View File

@ -78,6 +78,7 @@ if TYPE_CHECKING:
from argparse import Namespace
from vllm.config import ModelConfig, VllmConfig
from vllm.sequence import IntermediateTensors
logger = init_logger(__name__)
@ -1472,7 +1473,8 @@ def current_stream() -> torch.cuda.Stream:
# is hurting performance. Therefore creating a dedicated stream
# per process
if current_platform.is_rocm():
_current_stream_tls.value = torch.cuda.Stream()
# torch.cuda.set_stream here is the alias of _pathed_set_stream
torch.cuda.set_stream(torch.cuda.Stream())
elif current_platform.is_cpu():
_current_stream_tls.value = _StreamPlaceholder()
else:
@ -2278,7 +2280,8 @@ def weak_ref_tensor(tensor: Any) -> Any:
def weak_ref_tensors(
tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor],
IntermediateTensors]
) -> Union[torch.Tensor, list[Any], tuple[Any], Any]:
"""
Convenience function to create weak references to tensors,
@ -2290,6 +2293,15 @@ def weak_ref_tensors(
return [weak_ref_tensor(t) for t in tensors]
if isinstance(tensors, tuple):
return tuple(weak_ref_tensor(t) for t in tensors)
# For IntermediateTensors used in pipeline parallelism
from vllm.sequence import IntermediateTensors
if isinstance(tensors, IntermediateTensors):
ret = IntermediateTensors({
key: weak_ref_tensor(val)
for key, val in tensors.tensors.items()
})
return ret
raise ValueError("Invalid type for tensors")