mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Fused RMSNorm implementation (#153666)"
This reverts commit e1aee86646aa6d1b9cb9d34351e43936401c5efc.
Reverted https://github.com/pytorch/pytorch/pull/153666 on behalf of https://github.com/davidberard98 due to causing build failures on the main branch. [GH job link](https://github.com/pytorch/pytorch/actions/runs/16007148842/job/45156382001) [HUD commit link](e1aee86646) ([comment](https://github.com/pytorch/pytorch/pull/153666#issuecomment-3025146176))
This commit is contained in:
@@ -15,7 +15,7 @@ from torch._dispatch.python import enable_python_dispatcher
|
||||
from torch._export.utils import _is_cia_op
|
||||
from torch._ops import DispatchKey
|
||||
from torch.testing import make_tensor
|
||||
from torch.testing._internal.common_cuda import SM70OrLater, tf32_off
|
||||
from torch.testing._internal.common_cuda import tf32_off
|
||||
from torch.testing._internal.common_device_type import (
|
||||
instantiate_device_type_tests,
|
||||
onlyCPU,
|
||||
@@ -1226,33 +1226,6 @@ class DecompOneOffTests(TestCase):
|
||||
for o_ref, o in zip(out_ref, out):
|
||||
self.assertEqual(o_ref.dtype, o.dtype)
|
||||
|
||||
@onlyCUDA
@unittest.skipIf(not SM70OrLater, "triton")
def test_rms_norm_decomp_cuda(self, device):
    """Check that a compiled rms_norm -> sinh chain is fused by Inductor.

    Compiles a small function combining ``torch.nn.functional.rms_norm``
    with ``torch.sinh``, runs forward and backward once, and asserts that
    the generated Triton kernel names show the RMSNorm decomposition fused
    with the surrounding pointwise ops in both passes.
    """
    shape = (3, 3, 3)

    @torch.compile
    def fused_target(x, norm_shape, w):
        # rms_norm followed by a pointwise op that Inductor can fuse into
        # the same kernel.
        return torch.sinh(torch.nn.functional.rms_norm(x, norm_shape, w))

    inp = torch.randn(*shape, device=device, requires_grad=True)
    weight = torch.randn(*shape, device=device, requires_grad=True)

    _, generated_codes = torch._inductor.utils.run_fw_bw_and_get_code(
        lambda: fused_target(inp, shape, weight)
    )

    # Forward pass: the decomposed RMSNorm (mean/pow/rsqrt/mul) must land
    # in a single Triton kernel together with sinh.
    self.assertTrue(
        "triton_per_fused_add_mean_mul_pow_rsqrt_sinh" in generated_codes[0]
    )
    # Backward pass: the RMSNorm backward must likewise be fused with the
    # cosh (derivative of sinh) and mul ops.
    self.assertTrue(
        "triton_per_fused__fused_rms_norm_backward_cosh_mul" in generated_codes[1]
    )
|
||||
|
||||
|
||||
instantiate_device_type_tests(DecompOneOffTests, globals())
|
||||
|
||||
|
Reference in New Issue
Block a user