diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index 319fa938d40..235df1a77c5 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -454,7 +454,7 @@ class XIELU(CustomOp):
             )
             return result.view(original_shape)
 
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, input: torch.Tensor) -> torch.Tensor:
         if self._xielu_cuda_obj is not None and input.is_cuda:
             if not torch._dynamo.is_compiling():
                 return self._xielu_cuda_fn(input)
@@ -464,6 +464,9 @@ class XIELU(CustomOp):
             )
         return self._xielu_python(input)
 
+    def forward_cuda(self, input: torch.Tensor) -> torch.Tensor:
+        return self.forward_native(input)
+
 
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index f80ba3a7aa2..89676f98cb0 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1593,7 +1593,7 @@ class FusedMoE(CustomOp):
         else:
             return tensor_model_parallel_all_reduce(final_hidden_states)
 
-    def forward(
+    def forward_native(
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
@@ -1627,6 +1627,13 @@ class FusedMoE(CustomOp):
             return (shared_output[..., :og_hidden_states],
                     fused_output[..., :og_hidden_states])
 
+    def forward_cuda(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        return self.forward_native(hidden_states, router_logits)
+
     def forward_impl_chunked(
         self,
         full_hidden_states: torch.Tensor,
diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
index cd888b73342..7ac2e4bb6c3 100644
--- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
+++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
@@ -88,7 +88,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         cache = torch.cat((cos, sin), dim=-1)
         return cache
 
-    def forward(
+    def forward_native(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
@@ -129,3 +129,12 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
             query = query_rot
             key = key_rot
         return query, key
+
+    def forward_cuda(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        return self.forward_native(positions, query, key, offsets)
diff --git a/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py b/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py
index 3d8da0fa9d8..27e41dd0fa9 100644
--- a/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py
+++ b/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py
@@ -111,7 +111,7 @@ class DualChunkRotaryEmbedding(CustomOp):
                                       device=self.device)
         return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache
 
-    def forward(
+    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
@@ -161,6 +161,15 @@ class DualChunkRotaryEmbedding(CustomOp):
                         dim=-1)
         return query, key
 
+    def forward_cuda(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.forward_native(positions, query, key, offsets)
+
     def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
         cos, sin = cos_sin.chunk(2, dim=-1)
         if self.is_neox_style:
diff --git a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py
index 05322e56f26..4960c20f406 100644
--- a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py
+++ b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py
@@ -12,7 +12,7 @@ from .mrope import MRotaryEmbedding
 class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
     """3D rotary positional embedding. 3D is t:time h:height w:width"""
 
-    def forward(
+    def forward_native(  # type: ignore[override]
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
@@ -70,3 +70,11 @@ class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
                                    self.is_neox_style)
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
+
+    def forward_cuda(  # type: ignore[override]
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        return self.forward_native(positions, query, key)
\ No newline at end of file
diff --git a/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py b/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py
index 415a85ab698..37ead43e22b 100644
--- a/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py
+++ b/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py
@@ -53,7 +53,7 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
             torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
         return cache
 
-    def forward(
+    def forward_native(  # type: ignore[override]
         self,
         query: torch.Tensor,
         key: Optional[torch.Tensor] = None,
@@ -72,3 +72,10 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
         query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
         key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
         return query_out.type_as(query), key_out.type_as(key)
+
+    def forward_cuda(  # type: ignore[override]
+        self,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        return self.forward_native(query, key)
diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py
index 0ab4bc5375d..0acb5ea7424 100644
--- a/vllm/model_executor/layers/rotary_embedding/mrope.py
+++ b/vllm/model_executor/layers/rotary_embedding/mrope.py
@@ -8,7 +8,6 @@ import numpy as np
 import torch
 from transformers import PretrainedConfig
 
-from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 
 from .base import RotaryEmbedding
@@ -202,28 +201,6 @@ class MRotaryEmbedding(RotaryEmbedding):
         if self.mrope_section:
             assert sum(self.mrope_section) == rotary_dim // 2
 
-        self.use_triton = current_platform.is_cuda_alike()
-
-    def forward(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: Optional[torch.Tensor] = None,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-        """MRope forward.
-
-        Args:
-            positions:
-                [num_tokens,] (text only) or
-                [3, num_tokens] (T/H/W positions with multimodal inputs)
-            query: [num_tokens, num_heads * head_size]
-            key: [num_tokens, num_kv_heads * head_size]
-        """
-        if self.use_triton:
-            return self.forward_cuda(positions, query, key)
-        else:
-            return self.forward_native(positions, query, key)
-
     def forward_native(
         self,
         positions: torch.Tensor,
diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py
index 15e628177b3..c915ebac91c 100644
--- a/vllm/model_executor/layers/vocab_parallel_embedding.py
+++ b/vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -399,7 +399,7 @@ class VocabParallelEmbedding(CustomOp):
             param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
             param[loaded_weight.shape[0]:].data.fill_(0)
 
-    def forward(self, input_):
+    def forward_native(self, input_):
         if self.tp_size > 1:
             # Build the mask.
             masked_input, input_mask = get_masked_input_and_mask(
@@ -420,6 +420,9 @@ class VocabParallelEmbedding(CustomOp):
             output = tensor_model_parallel_all_reduce(output_parallel)
         return output
 
+    def forward_cuda(self, input_):
+        return self.forward_native(input_)
+
     def extra_repr(self) -> str:
         s = f"num_embeddings={self.num_embeddings_per_partition}"
         s += f", embedding_dim={self.embedding_dim}"
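Every class touched above is a `CustomOp` subclass (directly, or via `RotaryEmbedding`/`MRotaryEmbedding`), and in vLLM the `CustomOp` base class owns `forward()` and routes calls to a platform-specific implementation such as `forward_cuda` or `forward_native`. Overriding `forward()` directly, as these layers previously did, bypasses that routing; the diff converges on the convention of implementing `forward_native` and letting `forward_cuda` delegate to it. The sketch below is a minimal, simplified illustration of that dispatch pattern, not vLLM's actual `CustomOp` code: `SketchCustomOp`, `ToyScale`, and the `torch.cuda.is_available()` check are hypothetical stand-ins for the real platform dispatch.

```python
import torch
import torch.nn as nn


class SketchCustomOp(nn.Module):
    """Hypothetical stand-in for a CustomOp-style base class."""

    def __init__(self) -> None:
        super().__init__()
        # Resolve the platform-specific implementation once at construction.
        # (The real dispatch also covers HIP/CPU/XPU backends.)
        self._forward_method = (self.forward_cuda
                                if torch.cuda.is_available() else
                                self.forward_native)

    def forward(self, *args, **kwargs):
        # The base class owns forward() and routes to the chosen method.
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        # Default: fall back to the reference implementation.
        return self.forward_native(*args, **kwargs)


class ToyScale(SketchCustomOp):
    """Toy op following the same pattern as the layers in this diff."""

    def __init__(self, scale: float) -> None:
        super().__init__()
        self.scale = scale

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.scale

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # Same math as forward_native; a real op could call a CUDA kernel here.
        return self.forward_native(x)


if __name__ == "__main__":
    op = ToyScale(2.0)
    print(op(torch.ones(3)))  # tensor([2., 2., 2.])
```

Under that assumption, renaming `forward` to `forward_native` and adding a `forward_cuda` that simply delegates keeps CUDA behavior unchanged while restoring the base class's dispatch path, which is also why `MRotaryEmbedding` can drop its hand-rolled `use_triton` branch.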