Compare commits

...

2 Commits

SHA1 Message Date
9b0dad69f6 Update on "Fix accuracy for layernorm/rmsnorm benchmarking"
Example command:
    python benchmarks/dynamo/genai_layers/benchmark.py --exit-on-accuracy-failure --tolerance=1e-2 rmsnorm_backward

Fix the accuracy problem for layernorm/rmsnorm fwd/bwd.
Also fix some quack calls (possibly due to a quack API change).

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-22 11:29:53 -07:00
8e39c95a04 Fix accuracy for layernorm/rmsnorm benchmarking
Example command:
    python benchmarks/dynamo/genai_layers/benchmark.py --exit-on-accuracy-failure --tolerance=1e-2 rmsnorm_backward

Fix the accuracy problem for layernorm/rmsnorm fwd/bwd.
Also fix some quack calls (possibly due to a quack API change).

[ghstack-poisoned]
2025-10-21 10:22:19 -07:00
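
For context on the --tolerance flag: torch.testing.assert_close takes separate absolute and relative tolerances, and this patch feeds the single flag value into both. A minimal illustrative sketch of the semantics (not part of the patch):

    import torch

    # assert_close passes when |actual - expected| <= atol + rtol * |expected|.
    actual = torch.tensor([1.005])
    expected = torch.tensor([1.0])
    torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2)  # passes
    # With the fp32 defaults (rtol=1.3e-6, atol=1e-5), the same pair fails.
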
3 changed files with 130 additions and 47 deletions

View File

@@ -58,8 +58,7 @@ def list_benchmarks():
 def run_benchmark(
     benchmark_name: str,
-    should_visualize: bool = False,
-    compile_mode: str = "max-autotune-no-cudagraphs",
+    script_args,
 ):
     """Run a specific benchmark."""
     if benchmark_name not in BENCHMARK_REGISTRY:
@@ -68,29 +67,29 @@ def run_benchmark(
         return False

     print(f"Running benchmark: {benchmark_name}")
-    print(f"Torch compile mode: {compile_mode}")
+    print(f"Torch compile mode: {script_args.compile_mode}")
     print("=" * 60)

     benchmark_class = BENCHMARK_REGISTRY[benchmark_name]
-    benchmark = benchmark_class(compile_mode)
+    benchmark = benchmark_class(script_args)
     benchmark.benchmark()
-    if should_visualize:
+    if script_args.visualize:
         benchmark.visualize()
     return True


-def run_all_benchmarks(should_visualize: bool = False, compile_mode: str = "default"):
+def run_all_benchmarks(script_args):
     """Run all available benchmarks."""
     print("Running all benchmarks...")
-    print(f"Torch compile mode: {compile_mode}")
+    print(f"Torch compile mode: {script_args.compile_mode}")
     print("=" * 60)

     for name, cls in BENCHMARK_REGISTRY.items():
         print(f"\n{'=' * 20} {name.upper()} {'=' * 20}")
-        benchmark = cls(compile_mode)
+        benchmark = cls(script_args)
         benchmark.benchmark()
-        if should_visualize:
+        if script_args.visualize:
             benchmark.visualize()
         print()
@@ -137,6 +136,19 @@ Examples:
         help="Torch compile mode to use (default: default)",
     )
+    parser.add_argument(
+        "--tolerance",
+        type=float,
+        default=None,
+        help="Tolerance for the accuracy check",
+    )
+    parser.add_argument(
+        "--exit-on-accuracy-failure",
+        action="store_true",
+        help="Whether to exit with an error message for accuracy failure",
+    )

     args = parser.parse_args()

     # Handle list option
@@ -146,7 +158,7 @@ Examples:
     # Handle all option
     if args.all:
-        run_all_benchmarks(args.visualize, args.compile_mode)
+        run_all_benchmarks(args)
         return

     # Handle specific benchmarks
@@ -157,7 +169,7 @@ Examples:
         sys.exit(1)

     for benchmark_name in args.benchmarks:
-        run_benchmark(benchmark_name, args.visualize, args.compile_mode)
+        run_benchmark(benchmark_name, args)
         print()  # Add spacing between benchmarks
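
The refactor above threads the parsed argparse namespace through run_benchmark and run_all_benchmarks instead of forwarding individual flags, so benchmark classes can read compile_mode, visualize, tolerance, and exit_on_accuracy_failure from one object. A minimal sketch of the pattern, with names mirroring the diff (illustrative, not the real script):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--compile-mode", default="default")
    parser.add_argument("--visualize", action="store_true")
    parser.add_argument("--tolerance", type=float, default=None)
    parser.add_argument("--exit-on-accuracy-failure", action="store_true")

    args = parser.parse_args(["--tolerance", "1e-2"])
    # Passing the whole namespace means a new flag needs no signature changes
    # in run_benchmark or in the benchmark class constructors.
    assert args.tolerance == 1e-2 and not args.exit_on_accuracy_failure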

View File

@@ -9,8 +9,8 @@ import torch.nn.functional as F
 class CrossEntropyForward(BenchmarkKernel):
-    def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
-        super().__init__(compile_mode)
+    def __init__(self, script_args):
+        super().__init__(script_args)
         self.available_backends = ["eager", "compiled", "quack", "liger"]

     def get_shapes(self) -> tuple[tuple[int, ...], ...]:
@@ -106,8 +106,8 @@ class CrossEntropyForward(BenchmarkKernel):
 class CrossEntropyBackward(BenchmarkKernel):
-    def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
-        super().__init__(compile_mode)
+    def __init__(self, script_args):
+        super().__init__(script_args)
         self.available_backends = ["eager", "compiled", "quack", "liger"]

     def get_shapes(self) -> tuple[tuple[int, ...], ...]:
@@ -194,8 +194,8 @@ class CrossEntropyBackward(BenchmarkKernel):
 class SoftmaxForward(BenchmarkKernel):
-    def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
-        super().__init__(compile_mode)
+    def __init__(self, script_args):
+        super().__init__(script_args)
         self.available_backends = ["eager", "compiled", "quack", "liger"]

     def get_shapes(self) -> tuple[tuple[int, ...], ...]:
@@ -259,8 +259,8 @@ class SoftmaxForward(BenchmarkKernel):
 class SoftmaxBackward(BenchmarkKernel):
-    def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
-        super().__init__(compile_mode)
+    def __init__(self, script_args):
+        super().__init__(script_args)
         self.available_backends = ["eager", "compiled", "quack", "liger"]

     def get_shapes(self) -> tuple[tuple[int, ...], ...]:
@@ -329,8 +329,8 @@ class SoftmaxBackward(BenchmarkKernel):
 class RMSNormForward(BenchmarkKernel):
-    def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
-        super().__init__(compile_mode)
+    def __init__(self, script_args):
+        super().__init__(script_args)
         self.available_backends = ["eager", "compiled", "quack", "liger"]

     def get_shapes(self) -> tuple[tuple[int, ...], ...]:
@@ -383,7 +383,22 @@ class RMSNormForward(BenchmarkKernel):
         from quack.rmsnorm import _rmsnorm_fwd

         x, w = args
-        return lambda: _rmsnorm_fwd(x, w, eps=1e-6)
+        y = torch.empty_like(x)
+
+        def quack_fwd():
+            _rmsnorm_fwd(
+                x,
+                w,
+                out=y,
+                bias=None,
+                rstd=None,
+                residual=None,
+                residual_out=None,
+                eps=1e-6,
+            )
+            return y
+
+        return quack_fwd

     def liger(self, args, kwargs) -> Any:
         from liger_kernel.transformers.rms_norm import LigerRMSNorm
@@ -404,9 +419,14 @@ class RMSNormForward(BenchmarkKernel):
 class RMSNormBackward(BenchmarkKernel):
-    def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
-        super().__init__(compile_mode)
-        self.available_backends = ["eager", "compiled", "quack", "liger"]
+    def __init__(self, script_args):
+        super().__init__(script_args)
+        self.available_backends = [
+            "eager",
+            "compiled",
+            "quack",
+            "liger",
+        ]

     def get_shapes(self) -> tuple[tuple[int, ...], ...]:
         # TODO: OOM for (32768, 65536) on h100
@@ -454,8 +474,11 @@ class RMSNormBackward(BenchmarkKernel):
             y, [x, w], grad_outputs=dy, retain_graph=True
         )

+    def compute_rstd(self, x, eps):
+        return torch.rsqrt(torch.mean(x.float().square(), dim=-1, keepdim=True) + eps)
+
     def quack(self, args, kwargs=None) -> Any:
-        from quack.rmsnorm import _rmsnorm_backward
+        from quack.rmsnorm import _get_sm_count, _rmsnorm_bwd

         (
             x,
@@ -463,15 +486,40 @@ class RMSNormBackward(BenchmarkKernel):
             w,
             dy,
         ) = args
         M, N = x.shape
-        rstd = torch.randn(M, device="cuda", dtype=torch.float32)
-        return lambda: _rmsnorm_backward(x, w, dy, rstd)
+        rstd = self.compute_rstd(x, eps=1e-6)
+        dx = torch.empty_like(x)
+        sm_count = _get_sm_count(x.size(1), x.device)
+        dw_partial = torch.empty(
+            sm_count, x.size(1), device=x.device, dtype=torch.float32
+        )
+
+        def quack_bwd():
+            _rmsnorm_bwd(
+                x,
+                w,
+                dy,
+                rstd,
+                dx,
+                dw_partial,
+                db_partial=None,
+                dresidual_out=None,
+                dresidual=None,
+                sm_count=sm_count,
+            )
+            dw = dw_partial.sum(dim=0).to(w.dtype)
+            return dx, dw
+
+        return quack_bwd

     def liger(self, args, kwargs=None) -> Any:
         from liger_kernel.transformers.rms_norm import LigerRMSNorm

         x, w, dy = args
         M, N = x.shape
-        liger_rmsnorm = LigerRMSNorm(hidden_size=N, eps=1e-6).cuda()
+        liger_rmsnorm = LigerRMSNorm(
+            hidden_size=N, eps=1e-6, casting_mode="gemma"
+        ).cuda()
         liger_rmsnorm.weight.data.copy_(w)
         y = liger_rmsnorm(x)
         return lambda: torch.autograd.grad(
@@ -489,8 +537,8 @@ class RMSNormBackward(BenchmarkKernel):
 class LayerNormForward(BenchmarkKernel):
-    def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
-        super().__init__(compile_mode)
+    def __init__(self, script_args):
+        super().__init__(script_args)
         self.available_backends = ["eager", "compiled", "quack", "liger"]

     def get_shapes(self) -> tuple[tuple[int, ...], ...]:
@@ -563,8 +611,8 @@ class LayerNormForward(BenchmarkKernel):
 class LayerNormBackward(BenchmarkKernel):
-    def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
-        super().__init__(compile_mode)
+    def __init__(self, script_args):
+        super().__init__(script_args)
         self.available_backends = ["eager", "compiled", "liger"]

     def get_shapes(self) -> tuple[tuple[int, ...], ...]:
@@ -614,20 +662,31 @@ class LayerNormBackward(BenchmarkKernel):
             y, [x, w], grad_outputs=dy, retain_graph=True
         )

+    def compute_mean_rstd(self, x, eps):
+        x = x.float()
+        var, mean = torch.var_mean(x, dim=-1, keepdim=True, correction=0)
+        rstd = torch.rsqrt(var + eps)
+        return mean, rstd
+
     def liger(self, args, kwargs) -> Any:
-        from liger_kernel.transformers.layer_norm import LigerLayerNorm
+        """
+        Call layer_norm_backward directly rather than going through
+        liger_kernel.transformers.layer_norm.LigerLayerNorm and
+        torch.autograd.grad: that path saves mean/rstd in x.dtype,
+        which can fail the accuracy test. Here we call
+        layer_norm_backward with fp32 mean and rstd instead.
+        """
+        from liger_kernel.ops.layer_norm import layer_norm_backward

         x, w, dy = args
+        eps = 1e-6
+        mean, rstd = self.compute_mean_rstd(x, eps)
         M, N = x.shape
-        liger_layernorm = LigerLayerNorm(hidden_size=N, eps=1e-6).cuda()
-        liger_layernorm.weight.data.copy_(w)
-        liger_layernorm.bias.data.copy_(
-            torch.zeros(N, device="cuda", dtype=torch.float32)
-        )
-        y = liger_layernorm(x)
-        return lambda: torch.autograd.grad(
-            y, [x, liger_layernorm.weight], grad_outputs=dy, retain_graph=True
-        )
+        return lambda: layer_norm_backward(dy, x, w, None, mean, rstd)[0:2]

     def benchmark(self):
         for M, N in self.get_shapes():
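
The dtype concern in the docstring above is easy to reproduce standalone: normalization statistics computed in the input dtype (e.g. bf16) drift from an fp32 reference, which is why compute_rstd and compute_mean_rstd upcast before reducing. An illustrative check (not part of the patch):

    import torch

    x = torch.randn(8, 4096, dtype=torch.bfloat16)
    # Variance in the input dtype vs. upcast to fp32 first:
    var_lp, _ = torch.var_mean(x, dim=-1, keepdim=True, correction=0)
    var_fp, _ = torch.var_mean(x.float(), dim=-1, keepdim=True, correction=0)
    rstd_lp = torch.rsqrt(var_lp.float() + 1e-6)
    rstd_fp = torch.rsqrt(var_fp + 1e-6)
    # Nonzero drift; large enough to trip a tight accuracy check against a
    # reference that keeps its statistics in fp32.
    print((rstd_lp - rstd_fp).abs().max())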

View File

@@ -1,4 +1,5 @@
 import os
+import sys
 from collections import defaultdict
 from collections.abc import Callable
 from dataclasses import dataclass
@@ -43,10 +44,11 @@ class Performance:
 class BenchmarkKernel:
-    def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
+    def __init__(self, script_args):
+        self.script_args = script_args
         self.name = self.__class__.__name__
         self.available_backends: list[str] = []
-        self.compile_mode: str = compile_mode
+        self.compile_mode: str = script_args.compile_mode
         # mapping from backend to list of performance results
         self.profiling_results: defaultdict[str, list[Performance]] = defaultdict(list)
@@ -106,14 +108,21 @@ class BenchmarkKernel:
             args_ref, kwargs_ref = self.clone_inputs(args, kwargs)
             res[backend] = getattr(self, backend)(args_ref, kwargs_ref)()
         gold = res["eager"]

+        tol = {}
+        if self.script_args.tolerance:
+            tol = {
+                "atol": self.script_args.tolerance,
+                "rtol": self.script_args.tolerance,
+            }
         for backend in self.available_backends:
             if backend == "eager":
                 continue
             try:
-                torch.testing.assert_close(res[backend], gold)
+                torch.testing.assert_close(res[backend], gold, **tol)
                 for t, gold_t in zip(res[backend], gold):
                     if t.requires_grad:
-                        torch.testing.assert_close(t.grad, gold_t.grad)
+                        torch.testing.assert_close(t.grad, gold_t.grad, **tol)
                 print(
                     f"Accuracy check \033[92m✓ succeed\033[0m for {backend} backend on {self.name} kernel"
                 )
@@ -121,6 +130,9 @@ class BenchmarkKernel:
                 print(
                     f"Accuracy check \033[91m✗ failed\033[0m for {backend} backend on {self.name} kernel. Error {e}"
                )
+                if self.script_args.exit_on_accuracy_failure:
+                    print("Exit right away since --exit-on-accuracy-failure is set")
+                    sys.exit(1)

     def benchmark_single_shape(
         self, args, kwargs=None, should_check_accuracy=True, setting: str = ""
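
One subtlety in the tol handling above: "if self.script_args.tolerance:" treats an explicit 0.0 the same as unset, while an "is not None" check would let a caller request exact comparison. A small illustrative sketch of the kwargs pattern (hypothetical helper, not in the patch):

    import torch

    def make_tol(tolerance):
        # "is not None" distinguishes --tolerance=0.0 from the flag being unset;
        # an empty dict falls back to assert_close's dtype-based defaults.
        if tolerance is not None:
            return {"atol": tolerance, "rtol": tolerance}
        return {}

    torch.testing.assert_close(torch.ones(2), torch.ones(2), **make_tol(None))
    torch.testing.assert_close(torch.ones(2), torch.ones(2) + 1e-3, **make_tol(1e-2))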