Fused RMSNorm implementation (#153666)
Relevant: #72643

Benchmarked against the unfused torch implementation and a torch.compile implementation. Around a 9x speedup vs. the unfused implementation on CUDA, and slightly faster than the Inductor-compiled version on a 5090.

```py
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm_x = x.norm(2, dim=-1, keepdim=True)
        rms_x = norm_x * torch.rsqrt(torch.tensor(x.shape[-1], dtype=x.dtype))
        x_normed = x / (rms_x + self.eps)
        return self.scale * x_normed

def benchmark_rmsnorm_cuda(input_shape, normalized_dim, num_iterations=100,
                           warmup_iterations=10, dtype=torch.float16):
    rms_norm_layer = torch.nn.RMSNorm(normalized_dim, device='cuda', dtype=dtype)
    input_data = torch.randn(input_shape, device='cuda', dtype=dtype)

    # Fused implementation (torch.nn.RMSNorm)
    for _ in range(warmup_iterations):
        _ = rms_norm_layer(input_data)
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(num_iterations):
        _ = rms_norm_layer(input_data)
    end_event.record()
    torch.cuda.synchronize()

    elapsed_time_ms = start_event.elapsed_time(end_event)
    avg_time_ms = elapsed_time_ms / num_iterations
    print("--- RMSNorm CUDA Benchmark ---")
    print(f"Input Shape: {input_shape}")
    print(f"Normalized Dimension: {normalized_dim}")
    print(f"Benchmark Iterations: {num_iterations}")
    print("--- Fused Implementation ---")
    print(f"Average Time per Iteration: {avg_time_ms:.4f} ms")
    print(f"Total Time for {num_iterations} Iterations: {elapsed_time_ms:.3f} ms")

    # Unfused reference module compiled with torch.compile
    compiled_rms_norm = torch.compile(RMSNorm(dim=normalized_dim)).cuda()
    for _ in range(warmup_iterations):
        _ = compiled_rms_norm(input_data)
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(num_iterations):
        _ = compiled_rms_norm(input_data)
    end_event.record()
    torch.cuda.synchronize()

    elapsed_time_ms = start_event.elapsed_time(end_event)
    avg_time_ms = elapsed_time_ms / num_iterations
    print("--- TorchCompile Implementation ---")
    print(f"Average Time per Iteration: {avg_time_ms:.4f} ms")
    print(f"Total Time for {num_iterations} Iterations: {elapsed_time_ms:.3f} ms")
    print("-" * 50)

if __name__ == '__main__':
    parameter_sets = [
        {'batch_size': 16, 'sequence_length': 256,  'hidden_features': 512,  'dtype': torch.float16},
        {'batch_size': 32, 'sequence_length': 512,  'hidden_features': 768,  'dtype': torch.float16},
        {'batch_size': 64, 'sequence_length': 1024, 'hidden_features': 1024, 'dtype': torch.float16},
        {'batch_size': 32, 'sequence_length': 512,  'hidden_features': 768,  'dtype': torch.float32},
        {'batch_size': 8,  'sequence_length': 2048, 'hidden_features': 2048, 'dtype': torch.float16},
    ]

    num_benchmark_iterations = 200
    num_warmup_iterations = 20

    for params in parameter_sets:
        batch_size = params['batch_size']
        sequence_length = params['sequence_length']
        hidden_features = params['hidden_features']
        data_type = params.get('dtype', torch.float16)

        shape = (batch_size, sequence_length, hidden_features)
        norm_dim_to_normalize = hidden_features

        print(f"Benchmarking with: BS={batch_size}, SeqLen={sequence_length}, "
              f"Hidden={hidden_features}, DType={data_type}")
        benchmark_rmsnorm_cuda(input_shape=shape,
                               normalized_dim=norm_dim_to_normalize,
                               num_iterations=num_benchmark_iterations,
                               warmup_iterations=num_warmup_iterations,
                               dtype=data_type)
```
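As a quick correctness spot-check (my addition, not part of the PR), the eager `torch.nn.functional.rms_norm` path can be compared against a float64 reference; the shapes below are arbitrary:

```py
# Hypothetical sanity check (not from the PR): compare the eager rms_norm path
# against a float64 reference computing x * rsqrt(mean(x^2) + eps) * w.
import torch

x = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
w = torch.randn(1024, device="cuda", dtype=torch.float16)
eps = 1e-5

fused = torch.nn.functional.rms_norm(x, (1024,), w, eps)

xd, wd = x.double(), w.double()
ref = xd * torch.rsqrt(xd.pow(2).mean(dim=-1, keepdim=True) + eps) * wd
print((fused.double() - ref).abs().max())  # expect roughly fp16-level error
```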
Here are the Triton compile tests run on a 5090 (comparing this branch vs. main):

```py
import torch
import torch.nn as nn
from torch._inductor.utils import run_and_get_code, run_fw_bw_and_get_code

torch.manual_seed(0)
device = torch.device("cuda")

for batch in range(0, 9):
    for i in range(9, 16):
        normalized_shape_arg = (2**batch, 2**i)
        input_tensor = torch.randn(2**batch, 2**i, device=device, requires_grad=True)
        weight_tensor = torch.randn(2**batch, 2**i, device=device, requires_grad=True)

        model = torch.nn.functional.rms_norm
        compiled_model = torch.compile(model)

        loss = torch.randn_like(input_tensor)

        # Warmup
        num_iter = 5
        for j in range(num_iter):
            output = compiled_model(input_tensor, normalized_shape_arg, weight_tensor)
            output.backward(loss)

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()

        num_iter = 10
        for j in range(num_iter):
            output = compiled_model(input_tensor, normalized_shape_arg, weight_tensor)
            output.backward(loss)

        end_event.record()
        torch.cuda.synchronize()

        elapsed_time_ms = start_event.elapsed_time(end_event)
        avg_time_ms = round(elapsed_time_ms / num_iter, 5)
        print(2**batch, 2**i, avg_time_ms)
```

Each row below is `batch size, inner dim, average time per iteration (ms)`.

main

```
32 512 0.1812
32 1024 0.19021
32 2048 0.18871
32 4096 0.17019
32 8192 0.21944
32 16384 0.38871
32 32768 0.83282
64 512 0.14705
64 1024 0.13987
64 2048 0.14111
64 4096 0.21699
64 8192 0.43141
64 16384 0.90652
64 32768 2.18573
128 512 0.19361
128 1024 0.1963
128 2048 0.20122
128 4096 0.38888
128 8192 0.93795
128 16384 2.23437
128 32768 5.50079
256 512 0.16722
256 1024 0.22856
256 2048 0.39421
256 4096 0.96621
256 8192 2.48746
256 16384 5.53571
256 32768 11.97932
```

current branch

```
32 512 0.16328
32 1024 0.18104
32 2048 0.15508
32 4096 0.14356
32 8192 0.20111
32 16384 0.45974
32 32768 0.94799
64 512 0.16874
64 1024 0.18701
64 2048 0.16107
64 4096 0.20152
64 8192 0.46568
64 16384 0.96599
64 32768 2.21661
128 512 0.14982
128 1024 0.15565
128 2048 0.22241
128 4096 0.46128
128 8192 0.88883
128 16384 2.3097
128 32768 5.84448
256 512 0.14346
256 1024 0.2007
256 2048 0.45927
256 4096 0.87876
256 8192 2.10571
256 16384 5.73948
256 32768 12.98581
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153666
Approved by: https://github.com/ngimel
Committed by: PyTorch MergeBot
Parent: 1c8844d9e7
Commit: e1aee86646
```diff
@@ -820,6 +820,7 @@ def get_testing_overrides() -> dict[Callable, Callable]:
     torch._native_batch_norm_legit: lambda input, weight, bias, training, momentum, eps: -1,
     torch.native_dropout: lambda input, p, train: -1,
     torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
+    torch._fused_rms_norm: lambda input, normalized_shape, weight=None, eps=1e-05: -1,
     torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1,
     torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1,
     torch.native_channel_shuffle: lambda input, groups: -1,
```
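For context (an illustration of mine, not part of the diff): `get_testing_overrides()` maps each overridable function to a dummy lambda with a matching signature, which the `__torch_function__` testing machinery invokes. A minimal sketch of how the new entry can be looked up, assuming a build that contains this PR:

```py
# Minimal sketch, assuming torch._fused_rms_norm exists (i.e., a build with this PR).
import torch
from torch.overrides import get_testing_overrides

overrides = get_testing_overrides()
dummy = overrides[torch._fused_rms_norm]
# The dummy mirrors the real signature and simply returns -1.
print(dummy(torch.randn(2, 8), (8,), weight=None, eps=1e-5))  # -> -1
```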