[Hardware][CPU] using current_platform.is_cpu (#9536)

Author: wangshuai09
Date: 2024-10-22 15:50:43 +08:00
Committed by: GitHub
Parent: 0d02747f2e
Commit: 3ddbe25502
17 changed files with 60 additions and 64 deletions
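
The change is mechanical across the 17 files: call sites that previously imported the `is_cpu()` helper from `vllm.utils` now go through the platform object in `vllm.platforms`. A minimal before/after sketch of the pattern (illustrative only; the real call sites are in the hunks below):

```python
# Before: ad-hoc helper from vllm.utils
from vllm.utils import is_cpu

device = "cpu" if is_cpu() else "cuda"

# After: centralized platform object
from vllm.platforms import current_platform

device = "cpu" if current_platform.is_cpu() else "cuda"
```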

View File

@ -32,9 +32,10 @@ from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
to_enc_dec_tuple_list, zip_enc_dec_prompts)
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.sampling_params import BeamSearchParams
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
identity, is_cpu)
identity)
logger = init_logger(__name__)
@ -236,7 +237,8 @@ class HfRunner:
def wrap_device(self, input: _T, device: Optional[str] = None) -> _T:
if device is None:
return self.wrap_device(input, "cpu" if is_cpu() else "cuda")
return self.wrap_device(
input, "cpu" if current_platform.is_cpu() else "cuda")
if hasattr(input, "device") and input.device.type == device:
return input

View File

@ -7,8 +7,8 @@ from typing import List, Optional, Tuple
import pytest
from transformers import AutoModelForSeq2SeqLM
from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs
from vllm.utils import is_cpu
from ..conftest import DecoderPromptType
from ..models.utils import check_logprobs_close
@ -35,7 +35,7 @@ def vllm_to_hf_output(
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
@pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.skipif(
is_cpu(),
current_platform.is_cpu(),
reason="CPU backend is not currently supported with encoder/decoder models"
)
def test_encoder_decoder_e2e(
@ -50,7 +50,7 @@ def test_encoder_decoder_e2e(
enforce_eager: bool,
) -> None:
'''
End-to-End (E2E) test for the encoder-decoder framework.
This test evaluates the encoder-decoder functionality using the BART
model. We compare the outputs of the Hugging Face and vLLM
implementations to ensure that both implementations produce consistent

View File

@ -19,7 +19,8 @@ def test_env(name: str, device: str, monkeypatch):
override_backend_env_variable(monkeypatch, name)
if device == "cpu":
with patch("vllm.attention.selector.is_cpu", return_value=True):
with patch("vllm.attention.selector.current_platform.is_cpu",
return_value=True):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "TORCH_SDPA"
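
The patch target changes because `unittest.mock.patch` has to intercept the name where it is looked up: the selector no longer binds `is_cpu` into its own module namespace, so the test patches the method on the `current_platform` object as imported by `vllm.attention.selector`. A short sketch of the same idea:

```python
from unittest.mock import patch

# Patch the attribute where the selector actually reads it: the
# current_platform object imported inside vllm.attention.selector.
with patch("vllm.attention.selector.current_platform.is_cpu",
           return_value=True):
    ...  # code under test now sees a CPU platform
```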

View File

@ -5,7 +5,7 @@ Run `pytest tests/models/test_phimoe.py`.
import pytest
import torch
from vllm.utils import is_cpu
from vllm.platforms import current_platform
from ....utils import large_gpu_test
from ...utils import check_logprobs_close
@ -70,7 +70,7 @@ def test_phimoe_routing_function():
assert torch.equal(topk_ids, ground_truth[test_id]["topk_ids"])
@pytest.mark.skipif(condition=is_cpu(),
@pytest.mark.skipif(condition=current_platform.is_cpu(),
reason="This test takes a lot time to run on CPU, "
"and vllm CI's disk space is not enough for this model.")
@large_gpu_test(min_gb=80)

View File

@ -3,8 +3,8 @@ from typing import List, Optional, Tuple, Type
import pytest
from vllm.multimodal.utils import rescale_image_size
from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs
from vllm.utils import is_cpu
from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from ...utils import check_logprobs_close
@ -46,7 +46,7 @@ def run_test(
All the image fixtures for the test are from IMAGE_ASSETS.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
@ -103,7 +103,7 @@ def run_test(
target_dtype = "half"
if is_cpu():
if current_platform.is_cpu():
target_dtype = "bfloat16"

View File

@ -7,7 +7,7 @@ from PIL.Image import Image
from transformers import AutoConfig
from vllm.multimodal.utils import rescale_image_size
from vllm.utils import is_cpu
from vllm.platforms import current_platform
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_ImageAssets)
@ -78,7 +78,7 @@ def run_test(
All the image fixtures for the test are from IMAGE_ASSETS.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
@ -244,7 +244,7 @@ def run_awq_test(
target_dtype = "half"
if is_cpu():
if current_platform.is_cpu():
target_dtype = "bfloat16"

View File

@ -10,8 +10,9 @@ from vllm.inputs import InputContext, token_inputs
from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.utils import rescale_image_size
from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs
from vllm.utils import is_cpu, is_hip
from vllm.utils import is_hip
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_ImageAssets)
@ -49,7 +50,7 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
target_dtype = "half"
if is_cpu():
if current_platform.is_cpu():
target_dtype = "bfloat16"
# ROCm Triton FA can run into shared memory issues with these models,

View File

@ -5,8 +5,8 @@ import torch
from vllm.config import ModelConfig, TaskOption
from vllm.inputs import InputContext
from vllm.platforms import current_platform
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
from vllm.utils import is_cpu
TokensText = Tuple[List[int], str]
@ -19,7 +19,7 @@ def check_outputs_equal(
name_1: str,
):
"""
Compare the two sequences generated by different models,
which should be equal.
"""
assert len(outputs_0_lst) == len(outputs_1_lst)
@ -255,7 +255,7 @@ def build_model_context(model_name: str,
mm_processor_kwargs: Optional[Dict] = None,
limit_mm_per_prompt: Optional[Dict] = None):
"""Creates an InputContext for a given model.
Args:
model_name: Name of the model being considered.
tokenizer_name: Name of the tokenizer being considered.
@ -270,7 +270,7 @@ def build_model_context(model_name: str,
if tokenizer_name is None:
tokenizer_name = model_name
if dtype is None:
dtype = "bfloat16" if is_cpu() else "half"
dtype = "bfloat16" if current_platform.is_cpu() else "half"
model_config = ModelConfig(
model_name,

View File

@ -5,8 +5,9 @@ import pytest
import torch
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import is_cpu, make_tensor_with_pad
from vllm.utils import make_tensor_with_pad
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import _get_graph_batch_size
@ -31,7 +32,7 @@ def _create_model_runner(model: str, *args,
return model_runner
@pytest.mark.skipif(condition=is_cpu(),
@pytest.mark.skipif(condition=current_platform.is_cpu(),
reason="CPU backend is currently "
"unsupported for encoder/ "
"decoder models")
@ -74,7 +75,7 @@ def test_empty_seq_group():
assert return_seq_lens is None
@pytest.mark.skipif(condition=is_cpu(),
@pytest.mark.skipif(condition=current_platform.is_cpu(),
reason="CPU backend is currently "
"unsupported for encoder/ "
"decoder models")
@ -264,7 +265,7 @@ def test_prepare_prompt(batch_size):
assert torch.equal(actual, expected)
@pytest.mark.skipif(condition=is_cpu(),
@pytest.mark.skipif(condition=current_platform.is_cpu(),
reason="CPU backend is currently "
"unsupported for encoder/ "
"decoder models")
@ -490,7 +491,7 @@ def test_prepare_decode(batch_size, multiple_seqs_per_seq_group):
def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
"""
Tests that for encoder-decoder models with CUDA Graph capture and replay
enabled, the tensors used during the decode phase are correctly padded
for varying input batch sizes.
"""
model_runner = _create_model_runner(

View File

@ -10,9 +10,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm.utils import is_cpu
from vllm.platforms import current_platform
if is_cpu():
if current_platform.is_cpu():
try:
from vllm.attention.ops.ipex_attn import PagedAttention
except ImportError:
@ -234,10 +234,10 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths &
cross-attn block-tables fields
Encoder attn -> select encoder sequence lengths fields & no block tables
Arguments:
* attn_metadata: Attention metadata structure associated with attention

View File

@ -3,7 +3,7 @@ import math
import torch
from vllm.platforms import current_platform
from vllm.utils import is_cpu, is_hip
from vllm.utils import is_hip
from .utils import (dense_to_crow_col, get_head_sliding_step,
get_sparse_attn_mask)
@ -32,7 +32,7 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
):
super().__init__()
if use_spda is None:
use_spda = is_hip() or is_cpu() or not \
use_spda = is_hip() or current_platform.is_cpu() or not \
IS_COMPUTE_8_OR_ABOVE
device = device or (torch.cuda.current_device()
if current_platform.is_cuda_alike() else "cpu")
@ -109,13 +109,13 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
Support grouped attention, with `q[:, i*r:(i*r + r)]`
is correspondent to `k[:, i]`, where `r` is the q/k ratio.
cu_seqlens_k: shape=(batch_size + 1,),
indicating segment of samples,
e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i
cu_seqlens_q: shape=(batch_size + 1, ).
Default None: same as cu_seqlens_k for prefilling or
[0, 1, .., batch_size] for decoding.
The only case you need to specify is when q is a mix of
prefilling and decoding.
sm_scale: softmax scale, default to 1/sqrt(head_size).
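
The cumulative-length convention described above is the usual varlen layout: entry `i` marks where sample `i` starts along the packed token dimension. A small illustrative construction (the sample lengths here are made up):

```python
import torch

# Hypothetical key lengths for a packed batch of three sequences.
seq_lens_k = torch.tensor([4, 2, 5], dtype=torch.int32)

# cu_seqlens_k has shape (batch_size + 1,): here [0, 4, 6, 11].
cu_seqlens_k = torch.zeros(seq_lens_k.numel() + 1, dtype=torch.int32)
cu_seqlens_k[1:] = torch.cumsum(seq_lens_k, dim=0)

# k[cu_seqlens_k[i]:cu_seqlens_k[i + 1]] selects the keys of sample i.
# For decoding (one query token per sample), cu_seqlens_q defaults to
# [0, 1, ..., batch_size], i.e. torch.arange(batch_size + 1).
```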
@ -171,7 +171,7 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
"""For CPU, V100 or other older GPUs.
NOTE: torch SPDA supports nested tensor,
but seems extremely slow. Choose to pad instead.
"""
assert (cu_seqlens_q is None or
@ -201,8 +201,8 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
return self.transpose_and_unpad(spda_output, cu_seqlens)
def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
"""Dispatch to `varlen_attn` (Ampere or newer) or
`self.spda`(cpu, Volta, Turing or older)based on
"""Dispatch to `varlen_attn` (Ampere or newer) or
`self.spda`(cpu, Volta, Turing or older)based on
the type of device used and cuda compute capability.
q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
@ -213,8 +213,8 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
cu_seqlens_q: shape=(batch_size + 1, ).
Default None: same as cu_seqlens_k for prefilling or
[0, 1, .., batch_size] for decoding.
The only case you need to specify
is when q is a mix of prefilling
and decoding.
sm_scale: softmax scale, default to 1/sqrt(head_size).
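
The dispatch rule in this docstring can be restated as a small predicate; the sketch below is illustrative only (the real class precomputes `IS_COMPUTE_8_OR_ABOVE` and stores the decision as `use_spda`, as shown in the hunk above):

```python
import torch
from vllm.platforms import current_platform
from vllm.utils import is_hip


def use_spda_fallback() -> bool:
    """Illustrative restatement of the dispatch rule described above."""
    if is_hip() or current_platform.is_cpu():
        return True  # ROCm and CPU always take the padded SDPA path
    major, _ = torch.cuda.get_device_capability()
    return major < 8  # Volta/Turing and older: no varlen kernel
```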

View File

@ -10,7 +10,7 @@ import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import STR_BACKEND_ENV_VAR, is_cpu, is_hip, is_openvino, is_xpu
from vllm.utils import STR_BACKEND_ENV_VAR, is_hip, is_openvino, is_xpu
logger = init_logger(__name__)
@ -121,7 +121,7 @@ def get_attn_backend(
ROCmFlashAttentionBackend)
return ROCmFlashAttentionBackend
elif backend == _Backend.TORCH_SDPA:
assert is_cpu(), RuntimeError(
assert current_platform.is_cpu(), RuntimeError(
"Torch SDPA backend is only used for the CPU device.")
logger.info("Using Torch SDPA backend.")
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
@ -183,7 +183,7 @@ def which_attn_to_use(
if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var)
if is_cpu():
if current_platform.is_cpu():
if selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend)
return _Backend.TORCH_SDPA
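
In practice a user-requested backend is downgraded on CPU with an info log. A hedged sketch mirroring the selector test earlier in this diff (the env var name is the one `STR_BACKEND_ENV_VAR` points at, assumed here to be `VLLM_ATTENTION_BACKEND`):

```python
import os

import torch

os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"  # requested backend

from vllm.attention.selector import which_attn_to_use

backend = which_attn_to_use(16, torch.float16, torch.float16, 16, False)
# On a CPU-only build this logs that the requested backend cannot be used
# on CPU and returns _Backend.TORCH_SDPA.
```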

View File

@ -7,7 +7,7 @@ It takes over the control of the distributed environment from PyTorch.
The typical workflow is:
- call `init_distributed_environment` to initialize the distributed environment.
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
initialize the model parallel groups.
- any code dealing with the distributed stuff
@ -37,7 +37,7 @@ from torch.distributed import Backend, ProcessGroup
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import is_cpu, supports_custom_op
from vllm.utils import supports_custom_op
@dataclass
@ -1139,7 +1139,7 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
import ray # Lazy import Ray
ray.shutdown()
gc.collect()
if not is_cpu():
if not current_platform.is_cpu():
torch.cuda.empty_cache()

View File

@ -7,7 +7,7 @@ import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import is_cpu, is_hip, is_xpu, print_warning_once
from vllm.utils import is_hip, is_xpu, print_warning_once
logger = init_logger(__name__)
@ -74,7 +74,7 @@ class CustomOp(nn.Module):
if is_hip():
return self.forward_hip
elif is_cpu():
elif current_platform.is_cpu():
return self.forward_cpu
elif current_platform.is_tpu():
return self.forward_tpu
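
Here `current_platform` also drives the per-platform forward selection; the hunk returns the platform-specific bound method. A standalone sketch of the same selection order (illustrative only, not the real `CustomOp` API):

```python
from vllm.platforms import current_platform
from vllm.utils import is_hip


def select_forward(op):
    """Pick the platform-specific forward, mirroring the order shown above."""
    if is_hip():
        return op.forward_hip
    if current_platform.is_cpu():
        return op.forward_cpu
    if current_platform.is_tpu():
        return op.forward_tpu
    return op.forward_cuda  # assumed default branch; not shown in this hunk
```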

View File

@ -78,7 +78,7 @@ logger = init_logger(__name__)
class Qwen2VLImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape:
"""Shape:
`(num_patches, num_channels * patch_size * patch_size)`
"""
@ -102,14 +102,14 @@ Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs,
class Qwen2VLVideoInputs(TypedDict):
pixel_values_videos: torch.Tensor
"""Shape:
`(num_patches,
"""Shape:
`(num_patches,
num_channels * temporal_patch_size * patch_size * patch_size)`
"""
video_grid_thw: torch.Tensor
"""Shape: `(num_videos, 3)`
This should be in `(grid_t, grid_h, grid_w)` format.
"""

View File

@ -21,7 +21,7 @@ from vllm.model_executor.models import ModelRegistry
from vllm.multimodal.base import NestedTensors
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import is_cpu, is_pin_memory_available
from vllm.utils import is_pin_memory_available
logger = init_logger(__name__)
@ -474,7 +474,7 @@ def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
class LLMWrapper(nn.Module):
"""
To align with the key names of LoRA trained with PEFT, we need to add an
additional layer to the llm's implementation.
"""
@ -515,7 +515,7 @@ def get_vit_attn_backend() -> _Backend:
"so we use xformers backend instead. You can run "
"`pip install flash-attn` to use flash-attention backend.")
selected_backend = _Backend.XFORMERS
elif is_cpu():
elif current_platform.is_cpu():
selected_backend = _Backend.TORCH_SDPA
else:
selected_backend = _Backend.XFORMERS

View File

@ -318,15 +318,6 @@ def is_hip() -> bool:
return torch.version.hip is not None
@lru_cache(maxsize=None)
def is_cpu() -> bool:
from importlib.metadata import PackageNotFoundError, version
try:
return "cpu" in version("vllm")
except PackageNotFoundError:
return False
@lru_cache(maxsize=None)
def is_openvino() -> bool:
from importlib.metadata import PackageNotFoundError, version
@ -798,7 +789,7 @@ def is_pin_memory_available() -> bool:
elif is_neuron():
print_warning_once("Pin memory is not supported on Neuron.")
return False
elif is_cpu() or is_openvino():
elif current_platform.is_cpu() or is_openvino():
return False
return True
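
The deleted helper inferred a CPU build from the installed package version string; with it gone, out-of-tree code that did `from vllm.utils import is_cpu` now fails with an ImportError and should switch to the platform object. A purely illustrative compatibility shim for downstream code that has to span vLLM versions on both sides of this commit:

```python
# Illustrative shim for downstream code only; not part of vLLM.
try:
    from vllm.platforms import current_platform

    def _is_cpu() -> bool:
        return current_platform.is_cpu()
except ImportError:  # older vLLM without vllm.platforms
    from vllm.utils import is_cpu as _is_cpu
```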