Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Ensure rms_norm decomp generates add.Scalar for pattern match BC (#165437)

Summary:
Apparently if I just do `tensor + eps` this turns into add.Tensor, which is bad because the constant Tensor ends up getting hoisted into an input, which is a bozo thing to do. Just make sure it's exactly compatible.

Test Plan:
```
buck run 'fbcode//mode/opt' fbcode//bolt/nn/executorch/backends/tests:qnn_test_ar1g1 bolt.nn.executorch.backends.tests.qnn_test_ar1g1.QnnTestAR1G1.test_RMSNorm
```

Reviewed By: tugsbayasgalan

Differential Revision: D84613184

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165437
Approved by: https://github.com/tugsbayasgalan
Committed by: PyTorch MergeBot
Parent: 74acf92648
Commit: 08f09d9543

@@ -1783,7 +1783,10 @@ def _fused_rms_norm(
 
     rqrst_input = torch.rsqrt(
         # NB: don't inplace here, will violate functional IR invariant
-        torch.pow(upcasted_input, 2).mean(dim=dims_to_reduce, keepdim=True).add(eps_val)
+        # NB: carefully use the Scalar overload of add to ensure compatibility with the C++ decomp
+        torch.ops.aten.add.Scalar(
+            torch.pow(upcasted_input, 2).mean(dim=dims_to_reduce, keepdim=True), eps_val
+        )
     )
 
     upcasted_result = upcasted_input.mul(rqrst_input)
 
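
For readers who want to check the arithmetic in this hunk, the following is a minimal sketch in plain Python of what the patched expression computes: the reciprocal square root of the mean of squares plus epsilon, then an elementwise multiply. It deliberately avoids PyTorch, so it says nothing about which `add` overload is traced. The function name `rms_normalize` and the default `eps` are illustrative assumptions, not names from the PyTorch source, and the sketch covers only the lines shown in the diff.

```python
import math


def rms_normalize(values, eps=1e-6):
    """Scale each value by 1 / sqrt(mean(v**2) + eps).

    Hypothetical helper: mirrors the pow -> mean -> add(eps) -> rsqrt -> mul
    chain from the hunk for a single 1-D list of floats.
    """
    # Mean of squares over the reduced dimension (here: the whole list).
    mean_square = sum(v * v for v in values) / len(values)
    # eps is added as a plain scalar before the reciprocal square root.
    inv_rms = 1.0 / math.sqrt(mean_square + eps)
    # Counterpart of `upcasted_input.mul(rqrst_input)`.
    return [v * inv_rms for v in values]


normalized = rms_normalize([3.0, 4.0])
print(normalized)
```

With a small `eps`, the mean of squares of the output is close to 1, and a nonzero `eps` keeps an all-zero input from dividing by zero.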