mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Misc] Use torch.compile for GemmaRMSNorm (#7642)
This commit is contained in:
@ -114,10 +114,12 @@ class GemmaRMSNorm(CustomOp):
|
||||
self.weight = nn.Parameter(torch.zeros(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
@staticmethod
|
||||
def forward_static(
|
||||
weight: torch.Tensor,
|
||||
variance_epsilon: float,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
orig_dtype = x.dtype
|
||||
@ -127,17 +129,32 @@ class GemmaRMSNorm(CustomOp):
|
||||
|
||||
x = x.float()
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
x = x * torch.rsqrt(variance + variance_epsilon)
|
||||
# Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
|
||||
# See https://github.com/huggingface/transformers/pull/29402
|
||||
x = x * (1.0 + self.weight.float())
|
||||
x = x * (1.0 + weight.float())
|
||||
x = x.to(orig_dtype)
|
||||
return x if residual is None else (x, residual)
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
return self.forward_static(self.weight.data, self.variance_epsilon, x,
|
||||
residual)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
# TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm.
|
||||
if torch.compiler.is_compiling():
|
||||
return self.forward_native(x, residual)
|
||||
|
||||
if not getattr(self, "_is_compiled", False):
|
||||
self.forward_static = torch.compile( # type: ignore
|
||||
self.forward_static)
|
||||
self._is_compiled = True
|
||||
return self.forward_native(x, residual)
|
||||
|
Reference in New Issue
Block a user