Ensure rms_norm decomp generates add.Scalar for pattern match BC (#165437)

Summary: Apparently if I just write `tensor + eps`, it turns into `add.Tensor`, which is bad because the constant tensor ends up getting hoisted into a graph input, which is a bozo thing to do. Use the `add.Scalar` overload explicitly so the Python decomp stays exactly compatible with the C++ decomp that the backend pattern match expects.
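
As a rough illustration (not part of this PR; the helper names below are made up), `make_fx` traces down to ATen ops, so it shows which `add` overload each spelling of "x + eps" dispatches to:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

eps_val = 1e-6

def add_with_python_plus(x):
    # `+` between a tensor and a Python float dispatches to aten.add.Tensor;
    # the scalar gets wrapped into a constant tensor, which downstream export
    # passes may hoist into a graph input.
    return x + eps_val

def add_with_scalar_overload(x):
    # Explicitly picking the Scalar overload keeps eps as a plain scalar and
    # produces aten.add.Scalar in the traced graph, matching the C++ decomp.
    return torch.ops.aten.add.Scalar(x, eps_val)

x = torch.randn(4)
print(make_fx(add_with_python_plus)(x).graph)      # expect an aten.add.Tensor node
print(make_fx(add_with_scalar_overload)(x).graph)  # expect an aten.add.Scalar node
```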

Test Plan:
```
buck run 'fbcode//mode/opt' fbcode//bolt/nn/executorch/backends/tests:qnn_test_ar1g1 bolt.nn.executorch.backends.tests.qnn_test_ar1g1.QnnTestAR1G1.test_RMSNorm
```

Reviewed By: tugsbayasgalan

Differential Revision: D84613184

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165437
Approved by: https://github.com/tugsbayasgalan
Author: Edward Yang
Date: 2025-10-14 19:56:37 +00:00
Committed by: PyTorch MergeBot
Commit: 08f09d9543 (parent: 74acf92648)

```diff
@@ -1783,7 +1783,10 @@ def _fused_rms_norm(
     rqrst_input = torch.rsqrt(
         # NB: don't inplace here, will violate functional IR invariant
-        torch.pow(upcasted_input, 2).mean(dim=dims_to_reduce, keepdim=True).add(eps_val)
+        # NB: carefully use the Scalar overload of add to ensure compatibility with the C++ decomp
+        torch.ops.aten.add.Scalar(
+            torch.pow(upcasted_input, 2).mean(dim=dims_to_reduce, keepdim=True), eps_val
+        )
     )
     upcasted_result = upcasted_input.mul(rqrst_input)
```
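
For a rough sanity check outside the fbcode test above, one could export a small `torch.nn.RMSNorm` module and inspect the decomposed graph for the `add.Scalar` overload. This is only a sketch (the module name is made up, and whether the RMSNorm decomposition fires here depends on the decomposition table in your PyTorch build):

```python
import torch

class RMSNormOnly(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.RMSNorm(16, eps=1e-6)

    def forward(self, x):
        return self.norm(x)

# Export and decompose, then look at which call_function targets the graph contains.
ep = torch.export.export(RMSNormOnly(), (torch.randn(2, 16),)).run_decompositions()
targets = {n.target for n in ep.graph.nodes if n.op == "call_function"}

# After this change the eps addition should appear as aten.add.Scalar rather than
# aten.add.Tensor, so backend pattern matchers written against the C++ decomp
# (e.g. a QNN RMSNorm fusion) keep matching.
print(torch.ops.aten.add.Scalar in targets)
```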