[Bugfix] Fix platform-specific routing in CustomOp implementations (#24444)

Signed-off-by: Konrad Zawora <kzawora@habana.ai>

Author: Konrad Zawora
Date: 2025-09-11 19:15:01 +02:00
Committed by: GitHub
Parent commit: 1fdd5c42d7
Commit: 4aa23892d6

8 changed files with 53 additions and 30 deletions
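
For context, a minimal sketch of the dispatch pattern this fix restores (illustrative only, not vLLM's actual CustomOp class, whose routing also covers other platforms): the base class binds forward() to a platform-specific forward_* method, so subclasses should override forward_native() and, where a dedicated kernel exists, forward_cuda(), rather than forward() itself. The classes touched below overrode forward() directly, bypassing that per-platform routing; the fix renames those overrides to forward_native() and adds a forward_cuda() fallback. The torch.cuda.is_available() check below is a stand-in assumption for the real platform detection.

import torch
import torch.nn as nn


class DispatchedOp(nn.Module):
    """Toy base class that routes forward() to a per-platform method."""

    def __init__(self):
        super().__init__()
        # Hypothetical platform check; vLLM resolves this differently.
        self._forward_method = (self.forward_cuda if torch.cuda.is_available()
                                else self.forward_native)

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # Default: fall back to the reference implementation.
        return self.forward_native(x)


class MyOp(DispatchedOp):
    # Overriding forward() here would bypass the routing above; overriding
    # forward_native() (and, where a kernel exists, forward_cuda()) keeps it.
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.silu(x)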

View File

@@ -454,7 +454,7 @@ class XIELU(CustomOp):
         )
         return result.view(original_shape)
 
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, input: torch.Tensor) -> torch.Tensor:
         if self._xielu_cuda_obj is not None and input.is_cuda:
             if not torch._dynamo.is_compiling():
                 return self._xielu_cuda_fn(input)
@@ -464,6 +464,9 @@ class XIELU(CustomOp):
             )
         return self._xielu_python(input)
 
+    def forward_cuda(self, input: torch.Tensor) -> torch.Tensor:
+        return self.forward_native(input)
+
 
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.

View File

@@ -1593,7 +1593,7 @@ class FusedMoE(CustomOp):
         else:
             return tensor_model_parallel_all_reduce(final_hidden_states)
 
-    def forward(
+    def forward_native(
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
@@ -1627,6 +1627,13 @@ class FusedMoE(CustomOp):
             return (shared_output[..., :og_hidden_states],
                     fused_output[..., :og_hidden_states])
 
+    def forward_cuda(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        return self.forward_native(hidden_states, router_logits)
+
     def forward_impl_chunked(
         self,
         full_hidden_states: torch.Tensor,

View File

@@ -88,7 +88,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         cache = torch.cat((cos, sin), dim=-1)
         return cache
 
-    def forward(
+    def forward_native(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
@@ -129,3 +129,12 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
             query = query_rot
             key = key_rot
         return query, key
+
+    def forward_cuda(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        return self.forward_native(positions, query, key, offsets)

View File

@@ -111,7 +111,7 @@ class DualChunkRotaryEmbedding(CustomOp):
                                       device=self.device)
         return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache
 
-    def forward(
+    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
@@ -161,6 +161,15 @@ class DualChunkRotaryEmbedding(CustomOp):
                              dim=-1)
         return query, key
 
+    def forward_cuda(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.forward_native(positions, query, key, offsets)
+
     def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
         cos, sin = cos_sin.chunk(2, dim=-1)
         if self.is_neox_style:

View File

@@ -12,7 +12,7 @@ from .mrope import MRotaryEmbedding
 class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
     """3D rotary positional embedding. 3D is t:time h:height w:width"""
 
-    def forward(
+    def forward_native(  # type: ignore[override]
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
@@ -70,3 +70,11 @@ class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
                                              self.is_neox_style)
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
+
+    def forward_cuda(  # type: ignore[override]
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        return self.forward_native(positions, query, key)

View File

@@ -53,7 +53,7 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
             torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
         return cache
 
-    def forward(
+    def forward_native(  # type: ignore[override]
         self,
         query: torch.Tensor,
         key: Optional[torch.Tensor] = None,
@@ -72,3 +72,10 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
         query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
         key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
         return query_out.type_as(query), key_out.type_as(key)
+
+    def forward_cuda(  # type: ignore[override]
+        self,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        return self.forward_native(query, key)

View File

@@ -8,7 +8,6 @@ import numpy as np
 import torch
 from transformers import PretrainedConfig
 
-from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 
 from .base import RotaryEmbedding
@@ -202,28 +201,6 @@ class MRotaryEmbedding(RotaryEmbedding):
         if self.mrope_section:
             assert sum(self.mrope_section) == rotary_dim // 2
 
-        self.use_triton = current_platform.is_cuda_alike()
-
-    def forward(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: Optional[torch.Tensor] = None,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-        """MRope forward.
-
-        Args:
-            positions:
-                [num_tokens,] (text only) or
-                [3, num_tokens] (T/H/W positions with multimodal inputs)
-            query: [num_tokens, num_heads * head_size]
-            key: [num_tokens, num_kv_heads * head_size]
-        """
-        if self.use_triton:
-            return self.forward_cuda(positions, query, key)
-        else:
-            return self.forward_native(positions, query, key)
-
     def forward_native(
         self,
         positions: torch.Tensor,
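
The MRotaryEmbedding hunk above is the inverse case: its forward() hand-rolled the routing through a use_triton flag, duplicating the base dispatcher and only distinguishing the CUDA-alike and native paths. In terms of the toy DispatchedOp sketch (again hypothetical, not vLLM code), the removed method was roughly:

class ManuallyRoutedOp(DispatchedOp):
    # Redundant: forward() re-implements the platform routing that the base
    # class already performs through self._forward_method, and it only knows
    # about the CUDA and native paths.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if torch.cuda.is_available():  # stand-in for the removed use_triton flag
            return self.forward_cuda(x)
        return self.forward_native(x)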

View File

@@ -399,7 +399,7 @@ class VocabParallelEmbedding(CustomOp):
         param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
         param[loaded_weight.shape[0]:].data.fill_(0)
 
-    def forward(self, input_):
+    def forward_native(self, input_):
         if self.tp_size > 1:
             # Build the mask.
             masked_input, input_mask = get_masked_input_and_mask(
@@ -420,6 +420,9 @@ class VocabParallelEmbedding(CustomOp):
         output = tensor_model_parallel_all_reduce(output_parallel)
         return output
 
+    def forward_cuda(self, input_):
+        return self.forward_native(input_)
+
     def extra_repr(self) -> str:
         s = f"num_embeddings={self.num_embeddings_per_partition}"
         s += f", embedding_dim={self.embedding_dim}"