mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[CI Failure] Disable FlashInfer RoPE to unblock CI (#25299)
Signed-off-by: mgoin <mgoin64@gmail.com>
@@ -6,8 +6,6 @@ from typing import Optional
 import torch
 
 from vllm.model_executor.custom_op import CustomOp
-from vllm.platforms import current_platform
-from vllm.utils.flashinfer import has_flashinfer
 
 from .common import apply_rotary_emb_torch
 
@@ -32,13 +30,15 @@ class RotaryEmbedding(CustomOp):
         self.base = base
         self.is_neox_style = is_neox_style
         self.dtype = dtype
+        # TODO(mgoin): disabled for now due to failures
         # Flashinfer only supports head_size=64, 128, 256, 512.
         # https://github.com/flashinfer-ai/flashinfer/blob/ebfd655efe830048dba5d582aaa61d61d1cf9a87/include/flashinfer/utils.cuh#L174-L202
-        self.use_flashinfer = (self.enabled()
-                               and dtype in (torch.float16, torch.bfloat16)
-                               and current_platform.is_cuda()
-                               and has_flashinfer()
-                               and self.head_size in [64, 128, 256, 512])
+        # self.use_flashinfer = (self.enabled()
+        #                        and dtype in (torch.float16, torch.bfloat16)
+        #                        and current_platform.is_cuda()
+        #                        and has_flashinfer()
+        #                        and self.head_size in [64, 128, 256, 512])
+        self.use_flashinfer = False
 
         cache = self._compute_cos_sin_cache()
         if not self.use_flashinfer:
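With use_flashinfer forced to False, RotaryEmbedding always takes the pure-PyTorch path (apply_rotary_emb_torch) instead of the FlashInfer kernel. As a rough sketch of what that fallback computes for the neox style, assuming the usual half-split rotation and a cos/sin cache built the way _compute_cos_sin_cache typically does (inverse frequencies over half the head dimension, outer product with positions); the helper name and shapes below are illustrative, not vLLM's actual code:

# Minimal sketch of neox-style RoPE in pure PyTorch (illustrative only,
# not vLLM's apply_rotary_emb_torch implementation).
import torch

def apply_rotary_emb_neox(x: torch.Tensor,
                          cos: torch.Tensor,
                          sin: torch.Tensor) -> torch.Tensor:
    """Rotate x of shape [..., head_size] with half-split (neox) RoPE."""
    x1, x2 = torch.chunk(x, 2, dim=-1)
    # cos/sin have shape [..., head_size // 2] and broadcast over leading dims.
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

# Cos/sin cache in the style of RotaryEmbedding._compute_cos_sin_cache
# (assumed layout): one row per position, head_size // 2 frequencies.
head_size, max_pos, base = 64, 16, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2).float() / head_size))
positions = torch.arange(max_pos).float()
freqs = torch.einsum("i,j->ij", positions, inv_freq)  # [max_pos, head_size // 2]
cos, sin = freqs.cos(), freqs.sin()

q = torch.randn(max_pos, head_size)
q_rot = apply_rotary_emb_neox(q, cos, sin)
print(q_rot.shape)  # torch.Size([16, 64])

Note that the commented-out guard above gated the kernel on dtype (fp16/bf16), CUDA availability, FlashInfer presence, and the supported head sizes (64, 128, 256, 512); hard-coding the flag to False bypasses all of those checks until the failures are resolved.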