diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 506f1b408ae7..b1ac83c740c5 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1783,7 +1783,10 @@ def _fused_rms_norm( rqrst_input = torch.rsqrt( # NB: don't inplace here, will violate functional IR invariant - torch.pow(upcasted_input, 2).mean(dim=dims_to_reduce, keepdim=True).add(eps_val) + # NB: carefully use the Scalar overload of add to ensure compatibility with the C++ decomp + torch.ops.aten.add.Scalar( + torch.pow(upcasted_input, 2).mean(dim=dims_to_reduce, keepdim=True), eps_val + ) ) upcasted_result = upcasted_input.mul(rqrst_input)