# benchmarks/dynamo/genai_layers/kernels.py
#
# Benchmark kernels for common GenAI layers (cross entropy, softmax, RMSNorm,
# LayerNorm), comparing eager, torch.compile, quack, and liger backends.
#
# Example command:
#     python benchmarks/dynamo/genai_layers/benchmark.py --exit-on-accuracy-failure --tolerance=1e-2 rmsnorm_backward
from typing import Any

import cutlass
import cutlass.torch as cutlass_torch
from utils import BenchmarkKernel

import torch
import torch.nn.functional as F

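# The quack and liger backends are optional third-party baselines; their imports
# are deferred into the individual backend methods below, presumably so a missing
# package only affects the backends that need it rather than the whole module.
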
class CrossEntropyForward(BenchmarkKernel):
    def __init__(self, script_args):
        super().__init__(script_args)
        self.available_backends = ["eager", "compiled", "quack", "liger"]

    def get_shapes(self) -> tuple[tuple[int, ...], ...]:
        return (
            (32768, 256),
            (32768, 512),
            (32768, 1024),
            (32768, 2048),
            (32768, 4096),
            (32768, 8192),
            (32768, 16384),
            (32768, 32768),
            (32768, 65536),
            (16384, 131072),
            (8192, 262144),
        )

    def get_memory_bytes(self, args, kwargs) -> int:
        # Read x (M*N elements) + read target (M elements) + write loss (M elements)
        x, target = args
        M, N = x.shape
        dtype = x.dtype
        return (M * N + M + M) * dtype.itemsize

    def eager(self, args, kwargs=None) -> Any:
        assert kwargs is None
        x, target = args
        return lambda: F.cross_entropy(x, target, reduction="none")

    def compiled(self, args, kwargs=None) -> Any:
        assert kwargs is None
        x, target = args
        # Mark batch size as dynamic for realistic workload
        torch._dynamo.mark_dynamic(x, 0)
        torch._dynamo.mark_dynamic(target, 0)
        # Need `lambda` otherwise torch.compile will not trace the function.
        # More discussion: https://github.com/pytorch/pytorch/issues/158455
        compiled_cross_entropy = torch.compile(
            lambda x, target: F.cross_entropy(x, target, reduction="none"),
            mode=self.compile_mode,
            fullgraph=True,
        )
        return lambda: compiled_cross_entropy(x, target)

    def quack(self, args, kwargs=None) -> Any:
        assert kwargs is None
        x, target = args
        from quack.cross_entropy import _cross_entropy

        return lambda: _cross_entropy(x, target)

    def liger(self, args, kwargs=None) -> Any:
        assert kwargs is None
        from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

        x, target = args
        cross_entropy = LigerCrossEntropyLoss(reduction="none")
        return lambda: cross_entropy(x, target)

    def benchmark(self):
        for M, N in self.get_shapes():
            print(f"\n Tensor dimensions: [{M}, {N}]")
            # quack requires cutlass dtype
            torch_dtype = cutlass_torch.dtype(cutlass.BFloat16)
            x = 0.1 * torch.randn(M, N, device="cuda", dtype=torch_dtype)
            target = torch.randint(0, N, (M,), device="cuda", dtype=torch.int64)
            self.benchmark_single_shape((x, target), setting=f"shape: [{M}, {N}]")

    def check_accuracy(self, args, kwargs) -> None:
        res = {}
        for backend in self.available_backends:
            args_ref, kwargs_ref = self.clone_inputs(args, kwargs)
            res[backend] = getattr(self, backend)(args_ref, kwargs_ref)()
        gold = res["eager"]
        for backend in self.available_backends:
            if backend == "eager":
                continue
            if backend == "quack":
                # quack's cross_entropy only returns a float32 loss output.
                # Need to convert it to the same dtype as gold for comparison.
                res[backend] = res[backend].to(gold.dtype)
            try:
                torch.testing.assert_close(res[backend], gold)
                print(
                    f"Accuracy check \033[92m✓ succeeded\033[0m for {backend} backend on {self.name} kernel"
                )
            except Exception as e:
                print(
                    f"Accuracy check \033[91m✗ failed\033[0m for {backend} backend on {self.name} kernel. Error {e}"
                )

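
# A minimal illustrative sketch (not used by the harness) of how the
# get_memory_bytes() estimate above turns into an effective-bandwidth number;
# the real reporting lives in BenchmarkKernel/utils, and `_effective_gbps` is a
# hypothetical helper name.
def _effective_gbps(memory_bytes: int, latency_ms: float) -> float:
    # bytes / seconds, scaled to GB/s
    return memory_bytes / (latency_ms * 1e-3) / 1e9


# Example: the (32768, 8192) bf16 shape reads/writes
# (32768 * 8192 + 32768 + 32768) * 2 bytes ~= 0.54 GB, so a 0.5 ms run would be
# roughly 1.07 TB/s of effective bandwidth.
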
class CrossEntropyBackward(BenchmarkKernel):
    def __init__(self, script_args):
        super().__init__(script_args)
        self.available_backends = ["eager", "compiled", "quack", "liger"]

    def get_shapes(self) -> tuple[tuple[int, ...], ...]:
        return (
            (32768, 256),
            (32768, 512),
            (32768, 1024),
            (32768, 2048),
            (32768, 4096),
            (32768, 8192),
            (32768, 16384),
            (32768, 32768),
            (32768, 65536),
            (16384, 131072),
            (8192, 262144),
        )

    def get_memory_bytes(self, args, kwargs) -> int:
        # Read x (M*N elements) + read target (M elements) + read dloss (M elements)
        # + write grad (M*N elements)
        x, target, dloss = args
        M, N = x.shape
        return (
            2 * M * N * x.dtype.itemsize
            + M * target.dtype.itemsize
            + M * dloss.dtype.itemsize
        )

    def eager(self, args, kwargs=None) -> Any:
        assert kwargs is None
        x, target, dloss = args
        loss = F.cross_entropy(x, target, reduction="none")
        return lambda: torch.autograd.grad(
            loss, x, grad_outputs=dloss, retain_graph=True
        )

    def compiled(self, args, kwargs=None) -> Any:
        assert kwargs is None
        x, target, dloss = args
        compiled_cross_entropy = torch.compile(
            lambda x, target: F.cross_entropy(x, target, reduction="none"),
            mode=self.compile_mode,
            fullgraph=True,
        )
        loss = compiled_cross_entropy(x, target)
        return lambda: torch.autograd.grad(
            loss, x, grad_outputs=dloss, retain_graph=True
        )

    def quack(self, args, kwargs=None) -> Any:
        from quack.cross_entropy import cross_entropy

        assert kwargs is None
        x, target, dloss = args
        loss = cross_entropy(x, target)
        return lambda: torch.autograd.grad(
            loss, x, grad_outputs=dloss, retain_graph=True
        )

    def liger(self, args, kwargs=None) -> Any:
        assert kwargs is None
        from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

        x, target, dloss = args
        cross_entropy = LigerCrossEntropyLoss(reduction="none")
        loss = cross_entropy(x, target)
        return lambda: torch.autograd.grad(
            loss, x, grad_outputs=dloss, retain_graph=True
        )

    def benchmark(self):
        for M, N in self.get_shapes():
            print(f"Tensor dimensions: [{M}, {N}]")
            torch_dtype = cutlass_torch.dtype(cutlass.BFloat16)
            x = 0.1 * torch.randn(
                M, N, device="cuda", dtype=torch_dtype, requires_grad=True
            )
            target = torch.randint(0, N, (M,), device="cuda", dtype=torch.int64)
            dloss = torch.randn(M, device="cuda", dtype=torch.float32)
            self.benchmark_single_shape(
                (x, target, dloss), setting=f"shape: [{M}, {N}]"
            )

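
# For intuition: with reduction="none", the gradient the backends above produce
# is softmax(x) minus a one-hot of the target, scaled row-wise by dloss. A
# minimal fp32 reference sketch (not wired into the benchmark):
def _cross_entropy_bwd_ref(x, target, dloss):
    probs = F.softmax(x.float(), dim=-1)
    probs[torch.arange(x.shape[0], device=x.device), target] -= 1.0
    return (probs * dloss.float()[:, None]).to(x.dtype)
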
class SoftmaxForward(BenchmarkKernel):
    def __init__(self, script_args):
        super().__init__(script_args)
        self.available_backends = ["eager", "compiled", "quack", "liger"]

    def get_shapes(self) -> tuple[tuple[int, ...], ...]:
        return (
            (32768, 256),
            (32768, 512),
            (32768, 1024),
            (32768, 2048),
            (32768, 4096),
            (32768, 8192),
            (32768, 16384),
            (32768, 32768),
            (32768, 65536),
            (16384, 131072),
            (8192, 262144),
        )

    def get_memory_bytes(self, args, kwargs) -> int:
        (x,) = args
        M, N = x.shape
        return 2 * M * N * x.dtype.itemsize

    def eager(self, args, kwargs=None) -> Any:
        assert kwargs is None
        (x,) = args
        return lambda: F.softmax(x, dim=-1)

    def compiled(self, args, kwargs=None) -> Any:
        assert kwargs is None
        (x,) = args
        # Mark batch size as dynamic for realistic workload
        torch._dynamo.mark_dynamic(x, 0)
        compiled_softmax = torch.compile(
            lambda x: F.softmax(x, dim=-1), mode=self.compile_mode, fullgraph=True
        )
        return lambda: compiled_softmax(x)

    def quack(self, args, kwargs=None) -> Any:
        from quack.softmax import softmax

        assert kwargs is None
        (x,) = args
        return lambda: softmax(x)

    def liger(self, args, kwargs=None) -> Any:
        from liger_kernel.transformers.softmax import LigerSoftmax

        assert kwargs is None
        (x,) = args
        softmax = LigerSoftmax().to("cuda")
        return lambda: softmax(x)

    def benchmark(self):
        for M, N in self.get_shapes():
            print(f"Tensor dimensions: [{M}, {N}]")
            torch_dtype = cutlass_torch.dtype(cutlass.BFloat16)
            x = 0.1 * torch.randn(M, N, device="cuda", dtype=torch_dtype)
            self.benchmark_single_shape((x,), setting=f"shape: [{M}, {N}]")

class SoftmaxBackward(BenchmarkKernel):
    def __init__(self, script_args):
        super().__init__(script_args)
        self.available_backends = ["eager", "compiled", "quack", "liger"]

    def get_shapes(self) -> tuple[tuple[int, ...], ...]:
        return (
            (32768, 256),
            (32768, 512),
            (32768, 1024),
            (32768, 2048),
            (32768, 4096),
            (32768, 8192),
            (32768, 16384),
            (32768, 32768),
            (32768, 65536),
            (16384, 131072),
            (8192, 262144),
        )

    def get_memory_bytes(self, args, kwargs) -> int:
        # Memory: read y and dy, write dx
        x, dy = args
        M, N = x.shape
        return 3 * M * N * x.dtype.itemsize

    def eager(self, args, kwargs=None) -> Any:
        assert kwargs is None
        x, dy = args
        y = F.softmax(x, dim=-1)
        return lambda: torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True)

    def compiled(self, args, kwargs=None) -> Any:
        assert kwargs is None
        x, dy = args
        compiled_softmax = torch.compile(
            lambda x: F.softmax(x, dim=-1), mode=self.compile_mode, fullgraph=True
        )
        y = compiled_softmax(x)
        return lambda: torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True)

    def quack(self, args, kwargs=None) -> Any:
        from quack.softmax import softmax

        assert kwargs is None
        x, dy = args
        y = softmax(x)
        return lambda: torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True)

    def liger(self, args, kwargs=None) -> Any:
        from liger_kernel.transformers.softmax import LigerSoftmax

        assert kwargs is None
        x, dy = args
        softmax = LigerSoftmax().to("cuda")
        y = softmax(x)
        return lambda: torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True)

    def benchmark(self):
        for M, N in self.get_shapes():
            print(f"Tensor dimensions: [{M}, {N}]")
            torch_dtype = cutlass_torch.dtype(cutlass.BFloat16)
            x = 0.1 * torch.randn(
                M, N, device="cuda", dtype=torch_dtype, requires_grad=True
            )
            dy = torch.randn(M, N, device="cuda", dtype=torch_dtype)
            self.benchmark_single_shape((x, dy), setting=f"shape: [{M}, {N}]")

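
# For intuition: given y = softmax(x) and upstream grad dy, the input gradient
# is y * (dy - sum(dy * y, dim=-1, keepdim=True)). A minimal fp32 reference
# sketch (not wired into the benchmark):
def _softmax_bwd_ref(y, dy):
    y_f32, dy_f32 = y.float(), dy.float()
    dx = y_f32 * (dy_f32 - (dy_f32 * y_f32).sum(dim=-1, keepdim=True))
    return dx.to(y.dtype)
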
class RMSNormForward(BenchmarkKernel):
    def __init__(self, script_args):
        super().__init__(script_args)
        self.available_backends = ["eager", "compiled", "quack", "liger"]

    def get_shapes(self) -> tuple[tuple[int, ...], ...]:
        return (
            (32768, 256),
            (32768, 512),
            (32768, 1024),
            (32768, 2048),
            (32768, 4096),
            (32768, 8192),
            (32768, 16384),
            (32768, 32768),
            (32768, 65536),
            (16384, 131072),
            (8192, 262144),
        )

    def get_memory_bytes(self, args, kwargs) -> int:
        x, w = args
        M, N = x.shape
        return 2 * M * N * x.dtype.itemsize + N * w.dtype.itemsize

    def rms_norm_ref(self, x, w):
        x_f32 = x.float()
        return (
            x_f32
            * torch.rsqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6)
            * w
        ).to(x.dtype)

    def eager(self, args, kwargs=None) -> Any:
        assert kwargs is None
        x, w = args
        return lambda: self.rms_norm_ref(x, w)

    def compiled(self, args, kwargs=None) -> Any:
        assert kwargs is None
        x, w = args
        # Mark batch size as dynamic for realistic workload
        torch._dynamo.mark_dynamic(x, 0)
        compiled_rms_norm = torch.compile(
            self.rms_norm_ref, mode=self.compile_mode, fullgraph=True
        )
        return lambda: compiled_rms_norm(x, w)

    def quack(self, args, kwargs=None) -> Any:
        # Note: quack only supports weights with float32 dtype
        from quack.rmsnorm import _rmsnorm_fwd

        x, w = args
        y = torch.empty_like(x)

        def quack_fwd():
            _rmsnorm_fwd(
                x,
                w,
                out=y,
                bias=None,
                rstd=None,
                residual=None,
                residual_out=None,
                eps=1e-6,
            )
            return y

        return quack_fwd

    def liger(self, args, kwargs) -> Any:
        from liger_kernel.transformers.rms_norm import LigerRMSNorm

        x, w = args
        M, N = x.shape
        liger_rmsnorm = LigerRMSNorm(hidden_size=N, eps=1e-6).cuda()
        liger_rmsnorm.weight.data.copy_(w)
        return lambda: liger_rmsnorm(x)

    def benchmark(self):
        for M, N in self.get_shapes():
            print(f"Tensor dimensions: [{M}, {N}]")
            torch_dtype = cutlass_torch.dtype(cutlass.BFloat16)
            x = torch.randn(M, N, device="cuda", dtype=torch_dtype)
            w = torch.randn(N, device="cuda", dtype=torch.float32)
            self.benchmark_single_shape((x, w), setting=f"shape: [{M}, {N}]")

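
# Rough sanity check of RMSNormForward.get_memory_bytes: for the (32768, 8192)
# shape, bf16 x/y and an fp32 weight give 2 * 32768 * 8192 * 2 + 8192 * 4 bytes,
# i.e. roughly 1.07 GB of traffic per call.
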
class RMSNormBackward(BenchmarkKernel):
    def __init__(self, script_args):
        super().__init__(script_args)
        self.available_backends = ["eager", "compiled", "quack", "liger"]

    def get_shapes(self) -> tuple[tuple[int, ...], ...]:
        # TODO: OOM for (32768, 65536) on h100
        return (
            (32768, 256),
            (32768, 512),
            (32768, 1024),
            (32768, 2048),
            (32768, 4096),
            (32768, 8192),
            (32768, 16384),
            (32768, 32768),
        )

    def get_memory_bytes(self, args, kwargs) -> int:
        x, w, dy = args
        # x, dy: [M, N], w: [N]
        M, N = x.shape
        # Read x, w, dy, write dx, dw
        return 3 * M * N * x.dtype.itemsize + 2 * N * w.dtype.itemsize

    def rms_norm_ref(self, x, w):
        x_f32 = x.float()
        return (
            x_f32
            * torch.rsqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6)
            * w
        ).to(x.dtype)

    def eager(self, args, kwargs=None) -> Any:
        assert kwargs is None
        x, w, dy = args
        y = self.rms_norm_ref(x, w)
        return lambda: torch.autograd.grad(
            y, [x, w], grad_outputs=dy, retain_graph=True
        )

    def compiled(self, args, kwargs=None) -> Any:
        assert kwargs is None
        x, w, dy = args
        y = torch.compile(self.rms_norm_ref, mode=self.compile_mode, fullgraph=True)(
            x, w
        )
        return lambda: torch.autograd.grad(
            y, [x, w], grad_outputs=dy, retain_graph=True
        )

    def compute_rstd(self, x, eps):
        return torch.rsqrt(torch.mean(x.float().square(), dim=-1, keepdim=True) + eps)

    def quack(self, args, kwargs=None) -> Any:
        from quack.rmsnorm import _get_sm_count, _rmsnorm_bwd

        x, w, dy = args
        M, N = x.shape
        rstd = self.compute_rstd(x, eps=1e-6)
        dx = torch.empty_like(x)
        sm_count = _get_sm_count(x.size(1), x.device)
        dw_partial = torch.empty(
            sm_count, x.size(1), device=x.device, dtype=torch.float32
        )

        def quack_bwd():
            _rmsnorm_bwd(
                x,
                w,
                dy,
                rstd,
                dx,
                dw_partial,
                db_partial=None,
                dresidual_out=None,
                dresidual=None,
                sm_count=sm_count,
            )
            dw = dw_partial.sum(dim=0).to(w.dtype)
            return dx, dw

        return quack_bwd

    def liger(self, args, kwargs=None) -> Any:
        from liger_kernel.transformers.rms_norm import LigerRMSNorm

        x, w, dy = args
        M, N = x.shape
        liger_rmsnorm = LigerRMSNorm(
            hidden_size=N, eps=1e-6, casting_mode="gemma"
        ).cuda()
        liger_rmsnorm.weight.data.copy_(w)
        y = liger_rmsnorm(x)
        return lambda: torch.autograd.grad(
            y, [x, liger_rmsnorm.weight], grad_outputs=dy, retain_graph=True
        )

    def benchmark(self):
        for M, N in self.get_shapes():
            print(f"Tensor dimensions: [{M}, {N}]")
            torch_dtype = cutlass_torch.dtype(cutlass.BFloat16)
            x = torch.randn(M, N, device="cuda", dtype=torch_dtype, requires_grad=True)
            w = torch.randn(N, device="cuda", dtype=torch.float32, requires_grad=True)
            dy = torch.randn(M, N, device="cuda", dtype=torch_dtype)
            self.benchmark_single_shape((x, w, dy), setting=f"shape: [{M}, {N}]")

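
# For intuition: with rstd = 1/sqrt(mean(x^2) + eps) and y = x * rstd * w, the
# gradients the fused kernels above compute are, in fp32,
#   dx = rstd * (dy * w) - x * rstd**3 * mean(dy * w * x, dim=-1, keepdim=True)
#   dw = sum over rows of dy * x * rstd
# A minimal reference sketch, assuming the same eps=1e-6 as the benchmark setup:
def _rmsnorm_bwd_ref(x, w, dy, eps=1e-6):
    x_f32, w_f32, dy_f32 = x.float(), w.float(), dy.float()
    rstd = torch.rsqrt(x_f32.square().mean(dim=-1, keepdim=True) + eps)
    dyw = dy_f32 * w_f32
    dx = rstd * dyw - x_f32 * rstd.pow(3) * (dyw * x_f32).mean(dim=-1, keepdim=True)
    dw = (dy_f32 * x_f32 * rstd).sum(dim=0)
    return dx.to(x.dtype), dw.to(w.dtype)
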
class LayerNormForward(BenchmarkKernel):
    def __init__(self, script_args):
        super().__init__(script_args)
        self.available_backends = ["eager", "compiled", "quack", "liger"]

    def get_shapes(self) -> tuple[tuple[int, ...], ...]:
        # OOM for (16384, 131072) on h100
        return (
            (32768, 256),
            (32768, 512),
            (32768, 1024),
            (32768, 2048),
            (32768, 4096),
            (32768, 8192),
            (32768, 16384),
            (32768, 32768),
            (32768, 65536),
        )

    def get_memory_bytes(self, args, kwargs) -> int:
        x, w = args
        M, N = x.shape
        # Read x ([M, N]), w ([N]), write y ([M, N])
        return 2 * M * N * x.dtype.itemsize + N * w.dtype.itemsize

    def layernorm_ref(self, x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6):
        x_f32 = x.float()
        return F.layer_norm(x_f32, w.shape, w, None, eps).to(x.dtype)

    def eager(self, args, kwargs=None) -> Any:
        assert kwargs is None
        x, w = args
        return lambda: self.layernorm_ref(x, w)

    def compiled(self, args, kwargs=None) -> Any:
        assert kwargs is None
        x, w = args
        # Mark batch size as dynamic for realistic workload
        torch._dynamo.mark_dynamic(x, 0)
        compiled_layernorm = torch.compile(
            self.layernorm_ref, mode=self.compile_mode, fullgraph=True
        )
        return lambda: compiled_layernorm(x, w, eps=1e-6)

    def quack(self, args, kwargs) -> Any:
        # Note: quack layernorm does not support bias
        from quack.layernorm import layernorm

        x, w = args
        return lambda: layernorm(x, w, eps=1e-6)

    def liger(self, args, kwargs) -> Any:
        from liger_kernel.transformers.layer_norm import LigerLayerNorm

        x, w = args
        M, N = x.shape
        liger_layernorm = LigerLayerNorm(hidden_size=N, eps=1e-6).cuda()
        liger_layernorm.weight.data.copy_(w)
        liger_layernorm.bias.data.copy_(
            torch.zeros(N, device="cuda", dtype=torch.float32)
        )
        return lambda: liger_layernorm(x)

    def benchmark(self):
        for M, N in self.get_shapes():
            print(f"Tensor dimensions: [{M}, {N}]")
            torch_dtype = cutlass_torch.dtype(cutlass.BFloat16)
            x = torch.randn(M, N, device="cuda", dtype=torch_dtype)
            w = torch.randn(N, device="cuda", dtype=torch.float32)
            self.benchmark_single_shape((x, w), setting=f"shape: [{M}, {N}]")

class LayerNormBackward(BenchmarkKernel):
    def __init__(self, script_args):
        super().__init__(script_args)
        self.available_backends = ["eager", "compiled", "liger"]

    def get_shapes(self) -> tuple[tuple[int, ...], ...]:
        # OOM for (16384, 131072), (8192, 262144)
        return (
            (32768, 256),
            (32768, 512),
            (32768, 1024),
            (32768, 2048),
            (32768, 4096),
            (32768, 8192),
            (32768, 16384),
            (32768, 32768),
            (32768, 65536),
        )

    def get_memory_bytes(self, args, kwargs) -> int:
        x, w, dy = args
        M, N = x.shape
        # Read x ([M, N]), w ([N]), dy ([M, N]), write dx ([M, N]), dw ([N])
        return (
            2 * M * N * x.dtype.itemsize
            + 2 * N * w.dtype.itemsize
            + M * N * dy.dtype.itemsize
        )

    def layernorm_ref(self, x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6):
        x_f32 = x.float()
        return F.layer_norm(x_f32, w.shape, w, None, eps).to(x.dtype)

    def eager(self, args, kwargs=None) -> Any:
        assert kwargs is None
        x, w, dy = args
        y = self.layernorm_ref(x, w)
        return lambda: torch.autograd.grad(
            y, [x, w], grad_outputs=dy, retain_graph=True
        )

    def compiled(self, args, kwargs=None) -> Any:
        assert kwargs is None
        x, w, dy = args
        compiled_layernorm = torch.compile(
            self.layernorm_ref, mode=self.compile_mode, fullgraph=True
        )
        y = compiled_layernorm(x, w)
        return lambda: torch.autograd.grad(
            y, [x, w], grad_outputs=dy, retain_graph=True
        )

    def compute_mean_rstd(self, x, eps):
        x = x.float()
        var, mean = torch.var_mean(x, dim=-1, keepdim=True, correction=0)
        rstd = torch.rsqrt(var + eps)
        return mean, rstd

    def liger(self, args, kwargs) -> Any:
        """
        Call layer_norm_backward directly rather than calling
        liger_kernel.transformers.layer_norm.LigerLayerNorm and
        torch.autograd.grad. The latter saves mean/rstd in x.dtype, which can
        fail the accuracy test, so we call layer_norm_backward with fp32 mean
        and rstd instead.
        """
        from liger_kernel.ops.layer_norm import layer_norm_backward

        x, w, dy = args
        eps = 1e-6
        mean, rstd = self.compute_mean_rstd(x, eps)
        M, N = x.shape
        return lambda: layer_norm_backward(dy, x, w, None, mean, rstd)[0:2]

    def benchmark(self):
        for M, N in self.get_shapes():
            print(f"Tensor dimensions: [{M}, {N}]")
            torch_dtype = cutlass_torch.dtype(cutlass.BFloat16)
            x = torch.randn(M, N, device="cuda", dtype=torch_dtype, requires_grad=True)
            w = torch.randn(N, device="cuda", dtype=torch.float32, requires_grad=True)
            dy = torch.randn(M, N, device="cuda", dtype=torch_dtype)
            self.benchmark_single_shape((x, w, dy), setting=f"shape: [{M}, {N}]")
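

# For intuition: with xhat = (x - mean) * rstd and y = xhat * w (no bias), the
# gradients computed above are, in fp32,
#   g  = dy * w
#   dx = rstd * (g - mean(g, -1) - xhat * mean(g * xhat, -1))
#   dw = sum over rows of dy * xhat
# A minimal reference sketch (not wired into the benchmark):
def _layernorm_bwd_ref(x, w, dy, eps=1e-6):
    x_f32, w_f32, dy_f32 = x.float(), w.float(), dy.float()
    var, mean = torch.var_mean(x_f32, dim=-1, keepdim=True, correction=0)
    rstd = torch.rsqrt(var + eps)
    xhat = (x_f32 - mean) * rstd
    g = dy_f32 * w_f32
    dx = rstd * (
        g
        - g.mean(dim=-1, keepdim=True)
        - xhat * (g * xhat).mean(dim=-1, keepdim=True)
    )
    dw = (dy_f32 * xhat).sum(dim=0)
    return dx.to(x.dtype), dw.to(w.dtype)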