Mirror of https://github.com/vllm-project/vllm.git
[Hardware][CPU] using current_platform.is_cpu (#9536)
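This commit replaces the ad-hoc is_cpu() helper from vllm.utils with current_platform.is_cpu() from vllm.platforms across the test suite and runtime code, and deletes the old helper from vllm.utils. A minimal sketch of the call-site migration, assuming only the imports and names that appear in the diff below:

    # before: CPU detection through the helper that this commit deletes
    # from vllm.utils import is_cpu
    # target_dtype = "bfloat16" if is_cpu() else "half"

    # after: CPU detection through the platform abstraction
    from vllm.platforms import current_platform

    target_dtype = "bfloat16" if current_platform.is_cpu() else "half"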
@@ -32,9 +32,10 @@ from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
                          to_enc_dec_tuple_list, zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
+from vllm.platforms import current_platform
 from vllm.sampling_params import BeamSearchParams
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
-                        identity, is_cpu)
+                        identity)
 
 logger = init_logger(__name__)
 
@@ -236,7 +237,8 @@ class HfRunner:
 
     def wrap_device(self, input: _T, device: Optional[str] = None) -> _T:
         if device is None:
-            return self.wrap_device(input, "cpu" if is_cpu() else "cuda")
+            return self.wrap_device(
+                input, "cpu" if current_platform.is_cpu() else "cuda")
 
         if hasattr(input, "device") and input.device.type == device:
             return input
@@ -7,8 +7,8 @@ from typing import List, Optional, Tuple
 import pytest
 from transformers import AutoModelForSeq2SeqLM
 
+from vllm.platforms import current_platform
 from vllm.sequence import SampleLogprobs
-from vllm.utils import is_cpu
 
 from ..conftest import DecoderPromptType
 from ..models.utils import check_logprobs_close
@@ -35,7 +35,7 @@ def vllm_to_hf_output(
 @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
 @pytest.mark.parametrize("enforce_eager", [True, False])
 @pytest.mark.skipif(
-    is_cpu(),
+    current_platform.is_cpu(),
     reason="CPU backend is not currently supported with encoder/decoder models"
 )
 def test_encoder_decoder_e2e(
@@ -50,7 +50,7 @@ def test_encoder_decoder_e2e(
     enforce_eager: bool,
 ) -> None:
     '''
     End-to-End (E2E) test for the encoder-decoder framework.
     This test evaluates the encoder-decoder functionality using the BART
     model. We compare the outputs of the Hugging Face and vLLM
     implementations to ensure that both implementations produce consistent
@@ -19,7 +19,8 @@ def test_env(name: str, device: str, monkeypatch):
     override_backend_env_variable(monkeypatch, name)
 
     if device == "cpu":
-        with patch("vllm.attention.selector.is_cpu", return_value=True):
+        with patch("vllm.attention.selector.current_platform.is_cpu",
+                   return_value=True):
             backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
                                         False)
         assert backend.name == "TORCH_SDPA"
@@ -5,7 +5,7 @@ Run `pytest tests/models/test_phimoe.py`.
 import pytest
 import torch
 
-from vllm.utils import is_cpu
+from vllm.platforms import current_platform
 
 from ....utils import large_gpu_test
 from ...utils import check_logprobs_close
@@ -70,7 +70,7 @@ def test_phimoe_routing_function():
     assert torch.equal(topk_ids, ground_truth[test_id]["topk_ids"])
 
 
-@pytest.mark.skipif(condition=is_cpu(),
+@pytest.mark.skipif(condition=current_platform.is_cpu(),
                     reason="This test takes a lot time to run on CPU, "
                     "and vllm CI's disk space is not enough for this model.")
 @large_gpu_test(min_gb=80)
@@ -3,8 +3,8 @@ from typing import List, Optional, Tuple, Type
 import pytest
 
 from vllm.multimodal.utils import rescale_image_size
+from vllm.platforms import current_platform
 from vllm.sequence import SampleLogprobs
-from vllm.utils import is_cpu
 
 from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
 from ...utils import check_logprobs_close
@@ -46,7 +46,7 @@ def run_test(
 
     All the image fixtures for the test are from IMAGE_ASSETS.
     For huggingface runner, we provide the PIL images as input.
     For vllm runner, we provide MultiModalDataDict objects
     and corresponding MultiModalConfig as input.
     Note, the text input is also adjusted to abide by vllm contract.
     The text output is sanitized to be able to compare with hf.
@@ -103,7 +103,7 @@ def run_test(
 
 
 target_dtype = "half"
-if is_cpu():
+if current_platform.is_cpu():
     target_dtype = "bfloat16"
 
 
@@ -7,7 +7,7 @@ from PIL.Image import Image
 from transformers import AutoConfig
 
 from vllm.multimodal.utils import rescale_image_size
-from vllm.utils import is_cpu
+from vllm.platforms import current_platform
 
 from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
                           _ImageAssets)
@@ -78,7 +78,7 @@ def run_test(
 
     All the image fixtures for the test are from IMAGE_ASSETS.
     For huggingface runner, we provide the PIL images as input.
     For vllm runner, we provide MultiModalDataDict objects
     and corresponding MultiModalConfig as input.
     Note, the text input is also adjusted to abide by vllm contract.
     The text output is sanitized to be able to compare with hf.
@@ -244,7 +244,7 @@ def run_awq_test(
 
 
 target_dtype = "half"
-if is_cpu():
+if current_platform.is_cpu():
     target_dtype = "bfloat16"
 
 
@@ -10,8 +10,9 @@ from vllm.inputs import InputContext, token_inputs
 from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
 from vllm.multimodal import MultiModalRegistry
 from vllm.multimodal.utils import rescale_image_size
+from vllm.platforms import current_platform
 from vllm.sequence import SampleLogprobs
-from vllm.utils import is_cpu, is_hip
+from vllm.utils import is_hip
 
 from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
                           _ImageAssets)
@@ -49,7 +50,7 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
 
 
 target_dtype = "half"
-if is_cpu():
+if current_platform.is_cpu():
     target_dtype = "bfloat16"
 
 # ROCm Triton FA can run into shared memory issues with these models,
@@ -5,8 +5,8 @@ import torch
 
 from vllm.config import ModelConfig, TaskOption
 from vllm.inputs import InputContext
+from vllm.platforms import current_platform
 from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
-from vllm.utils import is_cpu
 
 TokensText = Tuple[List[int], str]
 
@@ -19,7 +19,7 @@ def check_outputs_equal(
     name_1: str,
 ):
     """
     Compare the two sequences generated by different models,
     which should be equal.
     """
     assert len(outputs_0_lst) == len(outputs_1_lst)
@@ -255,7 +255,7 @@ def build_model_context(model_name: str,
                         mm_processor_kwargs: Optional[Dict] = None,
                         limit_mm_per_prompt: Optional[Dict] = None):
     """Creates an InputContext for a given model.
 
     Args:
         model_name: Name of the model being considered.
         tokenizer_name: Name of the tokenizer being considered.
@@ -270,7 +270,7 @@ def build_model_context(model_name: str,
     if tokenizer_name is None:
         tokenizer_name = model_name
     if dtype is None:
-        dtype = "bfloat16" if is_cpu() else "half"
+        dtype = "bfloat16" if current_platform.is_cpu() else "half"
 
     model_config = ModelConfig(
         model_name,
@@ -5,8 +5,9 @@ import pytest
 import torch
 
 from vllm.engine.arg_utils import EngineArgs
+from vllm.platforms import current_platform
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
-from vllm.utils import is_cpu, make_tensor_with_pad
+from vllm.utils import make_tensor_with_pad
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 from vllm.worker.model_runner import _get_graph_batch_size
 
@@ -31,7 +32,7 @@ def _create_model_runner(model: str, *args,
     return model_runner
 
 
-@pytest.mark.skipif(condition=is_cpu(),
+@pytest.mark.skipif(condition=current_platform.is_cpu(),
                     reason="CPU backend is currently "
                     "unsupported for encoder/ "
                     "decoder models")
@@ -74,7 +75,7 @@ def test_empty_seq_group():
     assert return_seq_lens is None
 
 
-@pytest.mark.skipif(condition=is_cpu(),
+@pytest.mark.skipif(condition=current_platform.is_cpu(),
                     reason="CPU backend is currently "
                     "unsupported for encoder/ "
                     "decoder models")
@@ -264,7 +265,7 @@ def test_prepare_prompt(batch_size):
     assert torch.equal(actual, expected)
 
 
-@pytest.mark.skipif(condition=is_cpu(),
+@pytest.mark.skipif(condition=current_platform.is_cpu(),
                     reason="CPU backend is currently "
                     "unsupported for encoder/ "
                     "decoder models")
@@ -490,7 +491,7 @@ def test_prepare_decode(batch_size, multiple_seqs_per_seq_group):
 def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
     """
     Tests that for encoder-decoder models with CUDA Graph capture and replay
     enabled, the tensors used during the decode phase are correctly padded
     for varying input batch sizes.
     """
     model_runner = _create_model_runner(
@@ -10,9 +10,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.paged_attn import PagedAttentionMetadata
-from vllm.utils import is_cpu
+from vllm.platforms import current_platform
 
-if is_cpu():
+if current_platform.is_cpu():
     try:
         from vllm.attention.ops.ipex_attn import PagedAttention
     except ImportError:
@@ -234,10 +234,10 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
         on the type of attention operation.
 
         Decoder attn -> select entirely decoder self-attention-related fields
         Encoder/decoder cross-attn -> select encoder sequence lengths &
                                       cross-attn block-tables fields
         Encoder attn -> select encoder sequence lengths fields & no block tables
 
         Arguments:
 
         * attn_metadata: Attention metadata structure associated with attention
@@ -3,7 +3,7 @@ import math
 import torch
 
 from vllm.platforms import current_platform
-from vllm.utils import is_cpu, is_hip
+from vllm.utils import is_hip
 
 from .utils import (dense_to_crow_col, get_head_sliding_step,
                     get_sparse_attn_mask)
@@ -32,7 +32,7 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
     ):
         super().__init__()
         if use_spda is None:
-            use_spda = is_hip() or is_cpu() or not \
+            use_spda = is_hip() or current_platform.is_cpu() or not \
                     IS_COMPUTE_8_OR_ABOVE
         device = device or (torch.cuda.current_device()
                             if current_platform.is_cuda_alike() else "cpu")
@@ -109,13 +109,13 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
         q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
                 Support grouped attention, with `q[:, i*r:(i*r + r)]`
                 is correspondent to `k[:, i]`, where `r` is the q/k ratio.
         cu_seqlens_k: shape=(batch_size + 1,),
                 indicating segment of samples,
                 e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i
         cu_seqlens_q: shape=(batch_size + 1, ).
                 Default None: same as cu_seqlens_k for prefilling or
                 [0, 1, .., batch_size] for decoding.
                 The only case you need to specify is when q is a mix of
                 prefilling and decoding.
         sm_scale: softmax scale, default to 1/sqrt(head_size).
 
@@ -171,7 +171,7 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
 
     def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
         """For CPU, V100 or other older GPUs.
         NOTE: torch SPDA supports nested tensor,
         but seems extremely slow. Choose to pad instead.
         """
         assert (cu_seqlens_q is None or
@@ -201,8 +201,8 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
         return self.transpose_and_unpad(spda_output, cu_seqlens)
 
     def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
         """Dispatch to `varlen_attn` (Ampere or newer) or
         `self.spda`(cpu, Volta, Turing or older)based on
         the type of device used and cuda compute capability.
 
         q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
@@ -213,8 +213,8 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
         cu_seqlens_q: shape=(batch_size + 1, ).
                 Default None: same as cu_seqlens_k for prefilling or
                 [0, 1, .., batch_size] for decoding.
                 The only case you need to specify
                 is when q is a mix of prefilling
                 and decoding.
         sm_scale: softmax scale, default to 1/sqrt(head_size).
 
@@ -10,7 +10,7 @@ import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import STR_BACKEND_ENV_VAR, is_cpu, is_hip, is_openvino, is_xpu
+from vllm.utils import STR_BACKEND_ENV_VAR, is_hip, is_openvino, is_xpu
 
 logger = init_logger(__name__)
 
@@ -121,7 +121,7 @@ def get_attn_backend(
             ROCmFlashAttentionBackend)
         return ROCmFlashAttentionBackend
     elif backend == _Backend.TORCH_SDPA:
-        assert is_cpu(), RuntimeError(
+        assert current_platform.is_cpu(), RuntimeError(
            "Torch SDPA backend is only used for the CPU device.")
         logger.info("Using Torch SDPA backend.")
        from vllm.attention.backends.torch_sdpa import TorchSDPABackend
@@ -183,7 +183,7 @@ def which_attn_to_use(
     if backend_by_env_var is not None:
         selected_backend = backend_name_to_enum(backend_by_env_var)
 
-    if is_cpu():
+    if current_platform.is_cpu():
         if selected_backend != _Backend.TORCH_SDPA:
             logger.info("Cannot use %s backend on CPU.", selected_backend)
         return _Backend.TORCH_SDPA
@@ -7,7 +7,7 @@ It takes over the control of the distributed environment from PyTorch.
 The typical workflow is:
 
 - call `init_distributed_environment` to initialize the distributed environment.
 - call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
 initialize the model parallel groups.
 
 - any code dealing with the distributed stuff
@@ -37,7 +37,7 @@ from torch.distributed import Backend, ProcessGroup
 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import is_cpu, supports_custom_op
+from vllm.utils import supports_custom_op
 
 
 @dataclass
@@ -1139,7 +1139,7 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
         import ray  # Lazy import Ray
         ray.shutdown()
     gc.collect()
-    if not is_cpu():
+    if not current_platform.is_cpu():
         torch.cuda.empty_cache()
 
 
@@ -7,7 +7,7 @@ import vllm.envs as envs
 from vllm.compilation.levels import CompilationLevel
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import is_cpu, is_hip, is_xpu, print_warning_once
+from vllm.utils import is_hip, is_xpu, print_warning_once
 
 logger = init_logger(__name__)
 
@@ -74,7 +74,7 @@ class CustomOp(nn.Module):
 
         if is_hip():
             return self.forward_hip
-        elif is_cpu():
+        elif current_platform.is_cpu():
             return self.forward_cpu
         elif current_platform.is_tpu():
             return self.forward_tpu
@@ -78,7 +78,7 @@ logger = init_logger(__name__)
 class Qwen2VLImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     data: torch.Tensor
     """Shape:
     `(num_patches, num_channels * patch_size * patch_size)`
     """
 
@@ -102,14 +102,14 @@ Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs,
 
 class Qwen2VLVideoInputs(TypedDict):
     pixel_values_videos: torch.Tensor
     """Shape:
     `(num_patches,
      num_channels * temporal_patch_size * patch_size * patch_size)`
     """
 
     video_grid_thw: torch.Tensor
     """Shape: `(num_videos, 3)`
 
     This should be in `(grid_t, grid_h, grid_w)` format.
     """
 
@@ -21,7 +21,7 @@ from vllm.model_executor.models import ModelRegistry
 from vllm.multimodal.base import NestedTensors
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_cpu, is_pin_memory_available
+from vllm.utils import is_pin_memory_available
 
 logger = init_logger(__name__)
 
@@ -474,7 +474,7 @@ def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
 
 class LLMWrapper(nn.Module):
     """
     To align with the key names of LoRA trained with PEFT, we need to add an
     additional layer to the llm's implementation.
     """
 
@@ -515,7 +515,7 @@ def get_vit_attn_backend() -> _Backend:
                 "so we use xformers backend instead. You can run "
                 "`pip install flash-attn` to use flash-attention backend.")
             selected_backend = _Backend.XFORMERS
-    elif is_cpu():
+    elif current_platform.is_cpu():
         selected_backend = _Backend.TORCH_SDPA
     else:
         selected_backend = _Backend.XFORMERS
@@ -318,15 +318,6 @@ def is_hip() -> bool:
     return torch.version.hip is not None
 
 
-@lru_cache(maxsize=None)
-def is_cpu() -> bool:
-    from importlib.metadata import PackageNotFoundError, version
-    try:
-        return "cpu" in version("vllm")
-    except PackageNotFoundError:
-        return False
-
-
 @lru_cache(maxsize=None)
 def is_openvino() -> bool:
     from importlib.metadata import PackageNotFoundError, version
@@ -798,7 +789,7 @@ def is_pin_memory_available() -> bool:
     elif is_neuron():
         print_warning_once("Pin memory is not supported on Neuron.")
         return False
-    elif is_cpu() or is_openvino():
+    elif current_platform.is_cpu() or is_openvino():
         return False
     return True
 