mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[GPT-fast benchmark] Add MLP, gather + gemv, gemv micro benchmark (#128002)
Output example:

```
| name                         | metric                 | target | actual  |
|------------------------------|------------------------|--------|---------|
| layer_norm_bfloat16          | memory_bandwidth(GB/s) | 1017   | 1000.01 |
| mlp_layer_norm_gelu_bfloat16 | flops_utilization      | 0.71   | 0.71    |
| gemv_int8                    | memory_bandwidth(GB/s) | 990    | 984.06  |
| gemv_bfloat16                | memory_bandwidth(GB/s) | 1137   | 1137.92 |
| gather_gemv_int8             | memory_bandwidth(GB/s) | 1113   | 1111.09 |
| gather_gemv_bfloat16         | memory_bandwidth(GB/s) | 1249   | 1248.15 |
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128002
Approved by: https://github.com/Chillee
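The memory_bandwidth(GB/s) column is derived from the per-iteration time reported by triton's `do_bench` and the number of bytes each kernel has to move (for a gemv, essentially the weight matrix). A minimal sketch of that arithmetic; the helper name `memory_bandwidth_gb_s` and the 0.27 ms sample timing are illustrative, not part of the patch:

```
# Illustrative sketch of the metric, not part of the patch: GB/s from the
# bytes moved per iteration and the milliseconds reported by do_bench.
def memory_bandwidth_gb_s(bytes_per_iter: int, ms_per_iter: float) -> float:
    us_per_iter = ms_per_iter * 1000  # ms -> us, as in the benchmark code below
    return (1e6 / us_per_iter) * bytes_per_iter / 1e9  # bytes/s scaled to GB/s


# Example: an int8 gemv reads a D x D weight matrix (1 byte per element) once
# per iteration; the 0.27 ms timing here is a made-up placeholder.
D = 16384
print(memory_bandwidth_gb_s(bytes_per_iter=D * D, ms_per_iter=0.27))
```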
committed by PyTorch MergeBot
parent 4c84af0f5d
commit 1fb4effe7a
@@ -2,12 +2,17 @@ import argparse
import csv
import dataclasses
import os
import time

from generate import run_llama2_7b_bf16, run_llama2_7b_int8, run_mixtral_8x7b_int8
from triton.testing import do_bench

import torch
import torch.nn as nn
from torch.utils.flop_counter import FlopCounterMode

WARMUP_ITER = 5

A100_80G_BF16_TFLOPS = 312


@dataclasses.dataclass
@@ -18,57 +23,179 @@ class Experiment:
    actual: float


def do_inference(mod, x, num_samples: int = 5):
    total_time = 0
    start = -1

    for i in range(start, num_samples):
        torch.cuda.synchronize("cuda")

        t0 = time.perf_counter()
        mod(x)

        if i == -1:
            print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
            continue

        torch.cuda.synchronize("cuda")
        total_time += time.perf_counter() - t0

    total_time = total_time / num_samples

    return total_time


def run_multi_layer_norm():
    class MultiLayerNorm(nn.Module):
        def __init__(self, num_layers, normalized_shape, eps=1e-5, bias=True):
            super().__init__()
            self.num_layers = num_layers
            self.norm_layers = nn.ModuleList(
                [
                    nn.LayerNorm(normalized_shape, eps=eps, bias=bias)
                    for _ in range(num_layers)
                ]
            )

        def forward(self, x):
            for layer_norm in self.norm_layers:
                x = layer_norm(x)
            return x

    mod = MultiLayerNorm(num_layers=8, normalized_shape=4096).to("cuda")
    mod = torch.compile(mod)
    input = torch.randn([512, 1024, 4096], dtype=torch.bfloat16, device="cuda")
    inference_time = do_inference(mod, input)

    memory_bandwidth = input.numel() * input.dtype.itemsize / inference_time / 1e9

    return [
        Experiment(
            "multi_layer_norm", "memory_bandwidth(GB/s)", 92, f"{memory_bandwidth:.02f}"
        )
    ]


class SimpleMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dtype):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                nn.Linear(input_dim, hidden_dim, dtype=dtype),
                nn.LayerNorm(hidden_dim, dtype=dtype),
                nn.Linear(hidden_dim, output_dim, dtype=dtype),
                nn.LayerNorm(output_dim, dtype=dtype),
            ]
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def run_mlp_layer_norm_gelu():
    dtype_flops_utilization_map = {
        torch.bfloat16: "0.71",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    intermediate_size = 14336
    results = []
    for dtype, expected_flops_utilization in dtype_flops_utilization_map.items():
        flops_utilization = 0
        for D in input_shapes:
            mod = SimpleMLP(
                input_dim=D, hidden_dim=intermediate_size, output_dim=D, dtype=dtype
            ).to("cuda")

            x = torch.randn(D, device="cuda", dtype=torch.bfloat16)

            with FlopCounterMode(display=False) as mode:
                mod(x)

            flops = mode.get_total_flops()

            compiled_mod = torch.compile(mod, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled_mod(x)

            us_per_iter = do_bench(lambda: compiled_mod(x)) * 1000
            flops_utilization += us_per_iter * flops / 1e9 / A100_80G_BF16_TFLOPS

        flops_utilization = flops_utilization / len(input_shapes)
        dtype_str = str(dtype).replace("torch.", "")
        results.append(
            Experiment(
                f"mlp_layer_norm_gelu_{dtype_str}",
                "flops_utilization",
                expected_flops_utilization,
                f"{flops_utilization:.02f}",
            )
        )
    return results


def run_layer_norm():
    dtype_memory_bandwidth_map = {
        torch.bfloat16: "1017",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    BS = 4096
    results = []
    for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items():
        memory_bandwidth = 0
        for D in input_shapes:
            mod = nn.LayerNorm(D).to("cuda")

            x = torch.randn(BS, D, device="cuda", dtype=dtype)

            compiled_mod = torch.compile(mod, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled_mod(x)

            us_per_iter = do_bench(lambda: compiled_mod(x)) * 1000
            memory_bandwidth += (1e6 / us_per_iter) * 2 * BS * D * dtype.itemsize / 1e9

        memory_bandwidth = memory_bandwidth / len(input_shapes)
        dtype_str = str(dtype).replace("torch.", "")
        results.append(
            Experiment(
                f"layer_norm_{dtype_str}",
                "memory_bandwidth(GB/s)",
                expected_memory_bandwidth,
                f"{memory_bandwidth:.02f}",
            )
        )
    return results


@torch._inductor.config.patch(coordinate_descent_tuning=True)
def run_gather_gemv():
    E = 8
    dtype_memory_bandwidth_map = {
        torch.int8: "1113",
        torch.bfloat16: "1249",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    results = []
    for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items():
        memory_bandwidth = 0
        for D in input_shapes:

            def gather_gemv(W, score_idxs, x):
                return W[score_idxs].to(x.dtype) @ x

            W = torch.randn(E, D, D, device="cuda").to(dtype=dtype)
            x = torch.randn(D, device="cuda", dtype=torch.bfloat16)
            score_idxs = torch.tensor([3, 5], device="cuda")

            compiled_fn = torch.compile(gather_gemv, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled_fn(W, score_idxs, x)

            us_per_iter = do_bench(lambda: compiled_fn(W, score_idxs, x)) * 1000
            memory_bandwidth += (1e6 / us_per_iter) * 2 * D * D * dtype.itemsize / 1e9

        memory_bandwidth = memory_bandwidth / len(input_shapes)
        dtype_str = str(dtype).replace("torch.", "")
        results.append(
            Experiment(
                f"gather_gemv_{dtype_str}",
                "memory_bandwidth(GB/s)",
                expected_memory_bandwidth,
                f"{memory_bandwidth:.02f}",
            )
        )
    return results


@torch._inductor.config.patch(coordinate_descent_tuning=True)
def run_gemv():
    dtype_memory_bandwidth_map = {
        torch.int8: "990",
        torch.bfloat16: "1137",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    results = []
    for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items():
        memory_bandwidth = 0
        for D in input_shapes:

            def gemv(W, x):
                return W.to(x.dtype) @ x

            W = torch.randn(D, D, device="cuda").to(dtype=dtype)
            x = torch.randn(D, device="cuda", dtype=torch.bfloat16)

            compiled_fn = torch.compile(gemv, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled_fn(W, x)

            us_per_iter = do_bench(lambda: compiled_fn(W, x)) * 1000
            memory_bandwidth += (1e6 / us_per_iter) * D * D * dtype.itemsize / 1e9

        memory_bandwidth = memory_bandwidth / len(input_shapes)
        dtype_str = str(dtype).replace("torch.", "")
        results.append(
            Experiment(
                f"gemv_{dtype_str}",
                "memory_bandwidth(GB/s)",
                expected_memory_bandwidth,
                f"{memory_bandwidth:.02f}",
            )
        )
    return results


def output_csv(output_file, headers, row):
@@ -100,7 +227,10 @@ all_experiments = {
    run_llama2_7b_int8,
    run_mixtral_8x7b_int8,
    # A list of micro-benchmarks.
    run_multi_layer_norm,
    run_mlp_layer_norm_gelu,
    run_layer_norm,
    run_gather_gemv,
    run_gemv,
}
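Each entry in `all_experiments` is a function that returns a list of `Experiment` records (name, metric, target, actual), which the script then writes out through `output_csv`. A hypothetical sketch of how another micro-benchmark would slot into this file; the function name, target, and measured value below are placeholders, not real results:

```
# Hypothetical addition to benchmarks/gpt_fast/benchmark.py; Experiment is the
# dataclass defined in this file (name, metric, target, actual).
def run_my_microbenchmark():
    measured_gb_s = 0.0  # ...time a compiled kernel with do_bench, as above...
    return [
        Experiment(
            "my_microbenchmark",
            "memory_bandwidth(GB/s)",
            "1000",  # target value the actual result is compared against
            f"{measured_gb_s:.02f}",
        )
    ]


# Registering it alongside the existing entries:
# all_experiments = {
#     ...,
#     run_my_microbenchmark,
# }
```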