mirror of https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Bugfix] Fix platform-specific routing in CustomOp implementations (#24444)
Signed-off-by: Konrad Zawora <kzawora@habana.ai>
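Context for the fix: vLLM's `CustomOp` base class binds `forward` to a platform-specific implementation (`forward_cuda`, `forward_cpu`, `forward_tpu`, and so on, with `forward_native` as the pure-PyTorch fallback). The ops changed below overrode `forward` directly, which bypassed that routing on non-CUDA platforms; the fix renames those overrides to `forward_native` and, where the CUDA path is identical, adds a `forward_cuda` that delegates to it. The sketch below illustrates the resulting pattern with a hypothetical op (the class, its name, and its math are illustrative only and not part of this commit; the `CustomOp.register` decorator and import path reflect vLLM at the time of this change):

```python
import torch

from vllm.model_executor.custom_op import CustomOp


# Hypothetical op, for illustration only: shows the method layout this
# commit enforces (no direct `forward` override on CustomOp subclasses).
@CustomOp.register("relu_squared_example")
class ReLUSquaredExample(CustomOp):

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Pure-PyTorch fallback; used on CPU, TPU, and out-of-tree platforms.
        return torch.square(torch.relu(x))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # The CUDA path is the same math here, so delegate to the native
        # implementation, mirroring what this commit does for XIELU,
        # FusedMoE, VocabParallelEmbedding, and the RoPE variants.
        return self.forward_native(x)
```

With this layout, calling the module (`op(x)`) goes through `CustomOp.forward`, which resolves to `forward_cuda` on CUDA/ROCm and to `forward_native` (or another backend-specific method) elsewhere, rather than always executing whatever a hand-written `forward` override contained.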
@@ -454,7 +454,7 @@ class XIELU(CustomOp):
         )
         return result.view(original_shape)

-    def forward(self, input: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, input: torch.Tensor) -> torch.Tensor:
         if self._xielu_cuda_obj is not None and input.is_cuda:
             if not torch._dynamo.is_compiling():
                 return self._xielu_cuda_fn(input)
@@ -464,6 +464,9 @@ class XIELU(CustomOp):
                 )
         return self._xielu_python(input)

+    def forward_cuda(self, input: torch.Tensor) -> torch.Tensor:
+        return self.forward_native(input)
+

 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.

@@ -1593,7 +1593,7 @@ class FusedMoE(CustomOp):
         else:
             return tensor_model_parallel_all_reduce(final_hidden_states)

-    def forward(
+    def forward_native(
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
@@ -1627,6 +1627,13 @@ class FusedMoE(CustomOp):
             return (shared_output[..., :og_hidden_states],
                     fused_output[..., :og_hidden_states])

+    def forward_cuda(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        return self.forward_native(hidden_states, router_logits)
+
     def forward_impl_chunked(
         self,
         full_hidden_states: torch.Tensor,

@@ -88,7 +88,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         cache = torch.cat((cos, sin), dim=-1)
         return cache

-    def forward(
+    def forward_native(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
@@ -129,3 +129,12 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
             query = query_rot
             key = key_rot
         return query, key
+
+    def forward_cuda(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        return self.forward_native(positions, query, key, offsets)

@@ -111,7 +111,7 @@ class DualChunkRotaryEmbedding(CustomOp):
                                   device=self.device)
         return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache

-    def forward(
+    def forward_native(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
@@ -161,6 +161,15 @@ class DualChunkRotaryEmbedding(CustomOp):
                           dim=-1)
         return query, key

+    def forward_cuda(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.forward_native(positions, query, key, offsets)
+
     def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
         cos, sin = cos_sin.chunk(2, dim=-1)
         if self.is_neox_style:

@@ -12,7 +12,7 @@ from .mrope import MRotaryEmbedding
 class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
     """3D rotary positional embedding. 3D is t:time h:height w:width"""

-    def forward(
+    def forward_native(  # type: ignore[override]
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
@@ -70,3 +70,11 @@ class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
                                       self.is_neox_style)
             key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
+
+    def forward_cuda(  # type: ignore[override]
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        return self.forward_native(positions, query, key)

@@ -53,7 +53,7 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
             torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
         return cache

-    def forward(
+    def forward_native(  # type: ignore[override]
         self,
         query: torch.Tensor,
         key: Optional[torch.Tensor] = None,
@@ -72,3 +72,10 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
         query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
         key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
         return query_out.type_as(query), key_out.type_as(key)
+
+    def forward_cuda(  # type: ignore[override]
+        self,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        return self.forward_native(query, key)

@@ -8,7 +8,6 @@ import numpy as np
 import torch
 from transformers import PretrainedConfig

-from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton

 from .base import RotaryEmbedding
@@ -202,28 +201,6 @@ class MRotaryEmbedding(RotaryEmbedding):
         if self.mrope_section:
             assert sum(self.mrope_section) == rotary_dim // 2

-        self.use_triton = current_platform.is_cuda_alike()
-
-    def forward(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: Optional[torch.Tensor] = None,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-        """MRope forward.
-
-        Args:
-            positions:
-                [num_tokens,] (text only) or
-                [3, num_tokens] (T/H/W positions with multimodal inputs)
-            query: [num_tokens, num_heads * head_size]
-            key: [num_tokens, num_kv_heads * head_size]
-        """
-        if self.use_triton:
-            return self.forward_cuda(positions, query, key)
-        else:
-            return self.forward_native(positions, query, key)
-
     def forward_native(
         self,
         positions: torch.Tensor,

@@ -399,7 +399,7 @@ class VocabParallelEmbedding(CustomOp):
             param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
             param[loaded_weight.shape[0]:].data.fill_(0)

-    def forward(self, input_):
+    def forward_native(self, input_):
         if self.tp_size > 1:
             # Build the mask.
             masked_input, input_mask = get_masked_input_and_mask(
@@ -420,6 +420,9 @@ class VocabParallelEmbedding(CustomOp):
         output = tensor_model_parallel_all_reduce(output_parallel)
         return output

+    def forward_cuda(self, input_):
+        return self.forward_native(input_)
+
     def extra_repr(self) -> str:
         s = f"num_embeddings={self.num_embeddings_per_partition}"
         s += f", embedding_dim={self.embedding_dim}"