fix gemma3 results all zero (#17364)

Signed-off-by: mayuyuace <qiming1.zhang@intel.com>
This commit is contained in:
Qiming Zhang
2025-04-29 09:40:25 -07:00
committed by GitHub
parent a39203f99e
commit d3cf61b89b

View File

@ -241,7 +241,10 @@ class GemmaRMSNorm(CustomOp):
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype
if residual is not None:
x = x + residual
if orig_dtype == torch.float16:
x = x + residual.float()
else:
x = x + residual
residual = x
x = x.float()