Fused RMSNorm implementation (#153666)

Relevant #72643

Benchmarked versus unfused torch implementation and torch.compile implementation. Around 9x speedup vs unfused implementation on cuda and slightly faster vs inductor compile on 5090.

```py
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm_x = x.norm(2, dim=-1, keepdim=True)
        rms_x = norm_x * torch.rsqrt(torch.tensor(x.shape[-1], dtype=x.dtype))
        x_normed = x / (rms_x + self.eps)
        return self.scale * x_normed

def benchmark_rmsnorm_cuda(input_shape, normalized_dim, num_iterations=100, warmup_iterations=10, dtype=torch.float16):
    rms_norm_layer = torch.nn.RMSNorm(normalized_dim, device='cuda', dtype=dtype)
    input_data = torch.randn(input_shape, device='cuda', dtype=dtype)

    for _ in range(warmup_iterations):
        _ = rms_norm_layer(input_data)
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(num_iterations):
        _ = rms_norm_layer(input_data)

    end_event.record()
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    avg_time_ms = elapsed_time_ms / num_iterations

    print(f"--- RMSNorm CUDA Benchmark ---")
    print(f"Input Shape: {input_shape}")
    print(f"Normalized Dimension: {normalized_dim}")
    print(f"Benchmark Iterations: {num_iterations}")
    print(f"--- Fused Implementation ---")
    print(f"Average Time per Iteration: {avg_time_ms:.4f} ms")
    print(f"Total Time for {num_iterations} Iterations: {elapsed_time_ms:.3f} ms")

    compiled_rms_norm = torch.compile(RMSNorm(dim=normalized_dim)).cuda()
    for _ in range(warmup_iterations):
        _ = compiled_rms_norm(input_data)
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(num_iterations):
        _ = compiled_rms_norm(input_data)
    end_event.record()
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    avg_time_ms = elapsed_time_ms / num_iterations

    print(f"--- TorchCompile Implementation ---")
    print(f"Average Time per Iteration: {avg_time_ms:.4f} ms")
    print(f"Total Time for {num_iterations} Iterations: {elapsed_time_ms:.3f} ms")

    print("-" * 50)

if __name__ == '__main__':
    parameter_sets = [
        {'batch_size': 16, 'sequence_length': 256, 'hidden_features': 512, 'dtype': torch.float16},
        {'batch_size': 32, 'sequence_length': 512, 'hidden_features': 768, 'dtype': torch.float16},
        {'batch_size': 64, 'sequence_length': 1024, 'hidden_features': 1024, 'dtype': torch.float16},
        {'batch_size': 32, 'sequence_length': 512, 'hidden_features': 768, 'dtype': torch.float32},
        {'batch_size': 8, 'sequence_length': 2048, 'hidden_features': 2048, 'dtype': torch.float16},
    ]

    num_benchmark_iterations = 200
    num_warmup_iterations = 20

    for params in parameter_sets:
        batch_size = params['batch_size']
        sequence_length = params['sequence_length']
        hidden_features = params['hidden_features']
        data_type = params.get('dtype', torch.float16)

        shape = (batch_size, sequence_length, hidden_features)
        norm_dim_to_normalize = hidden_features

        print(f"Benchmarking with: BS={batch_size}, SeqLen={sequence_length}, Hidden={hidden_features}, DType={data_type}")
        benchmark_rmsnorm_cuda(input_shape=shape,
                               normalized_dim=norm_dim_to_normalize,
                               num_iterations=num_benchmark_iterations,
                               warmup_iterations=num_warmup_iterations,
                               dtype=data_type)
```

Here are the triton compile tests ran on a 5090 (comparing this branch vs main)
```py
import torch
import torch.nn as nn
from torch._inductor.utils import run_and_get_code, run_fw_bw_and_get_code

torch.manual_seed(0)

device = torch.device("cuda")

for batch in range(0, 9):
    for i in range(9, 16):
        normalized_shape_arg = (2**batch, 2**i)
        input_tensor = torch.randn(2**batch, 2**i, device=device, requires_grad=True)
        weight_tensor = torch.randn(2**batch, 2**i,device=device, requires_grad=True)

        model = torch.nn.functional.rms_norm
        compiled_model = torch.compile(model)
        loss = torch.randn_like(input_tensor)

        num_iter = 5
        for j in range(num_iter):
            output = compiled_model(input_tensor, normalized_shape_arg, weight_tensor)
            output.backward(loss)

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        num_iter = 10
        for j in range(num_iter):
            output = compiled_model(input_tensor, normalized_shape_arg, weight_tensor)
            output.backward(loss)

        end_event.record()
        torch.cuda.synchronize()

        elapsed_time_ms = start_event.elapsed_time(end_event)
        avg_time_ms = round(elapsed_time_ms / num_iter, 5)
        print(2**batch, 2**i, avg_time_ms)
```
main
```
32 512 0.1812
32 1024 0.19021
32 2048 0.18871
32 4096 0.17019
32 8192 0.21944
32 16384 0.38871
32 32768 0.83282
64 512 0.14705
64 1024 0.13987
64 2048 0.14111
64 4096 0.21699
64 8192 0.43141
64 16384 0.90652
64 32768 2.18573
128 512 0.19361
128 1024 0.1963
128 2048 0.20122
128 4096 0.38888
128 8192 0.93795
128 16384 2.23437
128 32768 5.50079
256 512 0.16722
256 1024 0.22856
256 2048 0.39421
256 4096 0.96621
256 8192 2.48746
256 16384 5.53571
256 32768 11.97932
```
current branch
```
32 512 0.16328
32 1024 0.18104
32 2048 0.15508
32 4096 0.14356
32 8192 0.20111
32 16384 0.45974
32 32768 0.94799
64 512 0.16874
64 1024 0.18701
64 2048 0.16107
64 4096 0.20152
64 8192 0.46568
64 16384 0.96599
64 32768 2.21661
128 512 0.14982
128 1024 0.15565
128 2048 0.22241
128 4096 0.46128
128 8192 0.88883
128 16384 2.3097
128 32768 5.84448
256 512 0.14346
256 1024 0.2007
256 2048 0.45927
256 4096 0.87876
256 8192 2.10571
256 16384 5.73948
256 32768 12.98581
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153666
Approved by: https://github.com/ngimel, https://github.com/eqy, https://github.com/albanD
This commit is contained in:
AaronWang04
2025-07-18 23:24:21 +00:00
committed by PyTorch MergeBot
parent 60b9b06a53
commit 15ef4f28df
19 changed files with 845 additions and 185 deletions

View File

@ -15,7 +15,7 @@ from torch._dispatch.python import enable_python_dispatcher
from torch._export.utils import _is_cia_op
from torch._ops import DispatchKey
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import tf32_off
from torch.testing._internal.common_cuda import SM70OrLater, tf32_off
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
onlyCPU,
@ -1226,6 +1226,33 @@ class DecompOneOffTests(TestCase):
for o_ref, o in zip(out_ref, out):
self.assertEqual(o_ref.dtype, o.dtype)
@onlyCUDA
@unittest.skipIf(not SM70OrLater, "triton")
def test_rms_norm_decomp_cuda(self, device):
@torch.compile
def rms_norm_sinh(a, b, c):
output = torch.nn.functional.rms_norm(a, b, c)
return torch.sinh(output)
normalized_shape_arg = (3, 3, 3)
input_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True)
weight_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True)
def forward_pass_fn():
return rms_norm_sinh(input_tensor, normalized_shape_arg, weight_tensor)
model_output, generated_codes = torch._inductor.utils.run_fw_bw_and_get_code(
forward_pass_fn
)
# check RMSNorm was fused with sinh
self.assertTrue(
"triton_per_fused_add_mean_mul_pow_rsqrt_sinh" in generated_codes[0]
)
self.assertTrue(
"triton_per_fused__fused_rms_norm_backward_cosh_mul" in generated_codes[1]
)
instantiate_device_type_tests(DecompOneOffTests, globals())