Revert "Fused RMSNorm implementation (#153666)"

This reverts commit e1aee86646aa6d1b9cb9d34351e43936401c5efc.

Reverted https://github.com/pytorch/pytorch/pull/153666 on behalf of https://github.com/davidberard98 due to causing build failures on main branch [GH job link](https://github.com/pytorch/pytorch/actions/runs/16007148842/job/45156382001) [HUD commit link](e1aee86646) ([comment](https://github.com/pytorch/pytorch/pull/153666#issuecomment-3025146176))
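
For context on what the reverted PR fused: RMS normalization scales the input by the reciprocal root-mean-square taken over the normalized dimensions, with no mean subtraction (unlike LayerNorm). A minimal unfused sketch follows; the function name and the `eps` default are illustrative, not taken from the PR:

```python
import torch

def rms_norm_reference(x, normalized_shape, weight=None, eps=1e-6):
    # Reduce over the trailing `normalized_shape` dimensions.
    dims = tuple(range(-len(normalized_shape), 0))
    # x / sqrt(mean(x^2) + eps) -- the pow/mean/add/rsqrt/mul chain that
    # appears in the fused Triton kernel name checked by the removed test below.
    out = x * torch.rsqrt(x.pow(2).mean(dim=dims, keepdim=True) + eps)
    return out * weight if weight is not None else out
```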
Author: PyTorch MergeBot
Date:   2025-07-01 18:46:45 +00:00
Parent: 3a5677a380
Commit: 6401d1d53d

14 changed files with 184 additions and 839 deletions


@@ -15,7 +15,7 @@ from torch._dispatch.python import enable_python_dispatcher
 from torch._export.utils import _is_cia_op
 from torch._ops import DispatchKey
 from torch.testing import make_tensor
-from torch.testing._internal.common_cuda import SM70OrLater, tf32_off
+from torch.testing._internal.common_cuda import tf32_off
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     onlyCPU,
@@ -1226,33 +1226,6 @@ class DecompOneOffTests(TestCase):
         for o_ref, o in zip(out_ref, out):
             self.assertEqual(o_ref.dtype, o.dtype)
 
-    @onlyCUDA
-    @unittest.skipIf(not SM70OrLater, "triton")
-    def test_rms_norm_decomp_cuda(self, device):
-        @torch.compile
-        def rms_norm_sinh(a, b, c):
-            output = torch.nn.functional.rms_norm(a, b, c)
-            return torch.sinh(output)
-
-        normalized_shape_arg = (3, 3, 3)
-        input_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True)
-        weight_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True)
-
-        def forward_pass_fn():
-            return rms_norm_sinh(input_tensor, normalized_shape_arg, weight_tensor)
-
-        model_output, generated_codes = torch._inductor.utils.run_fw_bw_and_get_code(
-            forward_pass_fn
-        )
-
-        # check RMSNorm was fused with sinh
-        self.assertTrue(
-            "triton_per_fused_add_mean_mul_pow_rsqrt_sinh" in generated_codes[0]
-        )
-        self.assertTrue(
-            "triton_per_fused__fused_rms_norm_backward_cosh_mul" in generated_codes[1]
-        )
-
 instantiate_device_type_tests(DecompOneOffTests, globals())
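
For reference, the removed test keyed off Inductor's kernel-naming convention: generated Triton kernel names concatenate the ops that were fused, so the forward kernel lists `add_mean_mul_pow_rsqrt_sinh`, and the backward kernel name contains `cosh` because the derivative of `sinh` is `cosh`. A standalone sketch of the same check, outside the test harness (assumes a CUDA build with Triton; `run_fw_bw_and_get_code` is a private utility and may change):

```python
import torch
import torch._inductor.utils

@torch.compile
def rms_norm_sinh(a, b, c):
    return torch.sinh(torch.nn.functional.rms_norm(a, b, c))

x = torch.randn(3, 3, 3, device="cuda", requires_grad=True)
w = torch.randn(3, 3, 3, device="cuda", requires_grad=True)

# Run one forward+backward pass and capture the Inductor-generated source:
# per the test above, generated_codes[0] holds the forward graph's code and
# generated_codes[1] the backward graph's.
_, generated_codes = torch._inductor.utils.run_fw_bw_and_get_code(
    lambda: rms_norm_sinh(x, (3, 3, 3), w)
)
print("sinh fused into forward kernel:", "sinh" in generated_codes[0])
print("cosh (d/dx sinh) in backward kernel:", "cosh" in generated_codes[1])
```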