mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
fix gemma3 results all zero (#17364)
Signed-off-by: mayuyuace <qiming1.zhang@intel.com>
This commit is contained in:
@ -241,7 +241,10 @@ class GemmaRMSNorm(CustomOp):
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
orig_dtype = x.dtype
|
||||
if residual is not None:
|
||||
x = x + residual
|
||||
if orig_dtype == torch.float16:
|
||||
x = x + residual.float()
|
||||
else:
|
||||
x = x + residual
|
||||
residual = x
|
||||
|
||||
x = x.float()
|
||||
|
Reference in New Issue
Block a user