mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156077 Approved by: https://github.com/Skylion007, https://github.com/malfet ghstack dependencies: #156069
296 lines
8.2 KiB
Python
import itertools
|
|
from collections import defaultdict
|
|
from contextlib import nullcontext
|
|
from dataclasses import asdict, dataclass
|
|
from typing import Callable
|
|
|
|
from tabulate import tabulate
|
|
from tqdm import tqdm
|
|
|
|
import torch
|
|
import torch.utils.benchmark as benchmark
|
|
from torch._inductor.utils import do_bench_using_profiling
|
|
from torch.nn.attention import sdpa_kernel, SDPBackend
|
|
from torch.nn.functional import scaled_dot_product_attention
|
|
|
|
|
|
def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    """Time ``func(*args, **kwargs)`` on CUDA via ``do_bench_using_profiling``.

    Returns the measured time in microseconds (the profiler reports
    milliseconds, so the result is scaled by 1e3).
    """

    def _thunk() -> None:
        # Bind the arguments so the profiler can invoke a zero-arg callable.
        func(*args, **kwargs)

    return do_bench_using_profiling(_thunk) * 1e3
|
|
|
|
|
|
def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    """Measure the median runtime of ``func(*args, **kwargs)`` in microseconds.

    Performs a short warmup, then uses ``torch.utils.benchmark.Timer`` with
    adaptive autoranging; the median (reported in seconds) is converted to
    microseconds.
    """
    # Warmup so one-time costs don't skew the measurement.
    for _ in range(5):
        func(*args, **kwargs)

    timer = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "func": func},
    )
    measurement = timer.adaptive_autorange(min_run_time=0.1)
    return measurement.median * 1e6
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class ExperimentConfig:
|
|
batch_size: int
|
|
num_heads: int
|
|
q_seq_len: int
|
|
kv_seq_len: int
|
|
embed_dim: int
|
|
is_causal: bool
|
|
dtype: torch.dtype
|
|
backend: SDPBackend
|
|
device: torch.device = torch.device("cuda")
|
|
|
|
@property
|
|
def head_dim(self) -> int:
|
|
return self.embed_dim // self.num_heads
|
|
|
|
def asdict(self):
|
|
dict_obj = asdict(self)
|
|
dict_obj["head_dim"] = self.head_dim
|
|
return dict_obj
|
|
|
|
|
|
@dataclass(frozen=True)
class ExperimentResults:
    """Timing and throughput measurements for one benchmark run."""

    forward_time: float  # microseconds
    backward_time: float  # microseconds
    forward_tflops: float
    backward_tflops: float

    def asdict(self):
        # `asdict` here resolves to the module-level dataclasses.asdict
        # import, not this method: class scope is not searched from inside
        # method bodies, so this does not recurse.
        return asdict(self)
|
|
|
|
|
|
@dataclass(frozen=True)
class Experiment:
    """Pairs a benchmark configuration with its measured results."""

    config: ExperimentConfig
    results: ExperimentResults

    def asdict(self) -> dict:
        # Single flat dict: config fields first, then the result fields.
        merged = dict(self.config.asdict())
        merged.update(self.results.asdict())
        return merged
|
|
|
|
|
|
def calculate_tflops(
    config: ExperimentConfig,
    time_us: float,
    is_backward: bool = False,
    sparsity: float = 0.0,
) -> float:
    """Compute achieved TFLOPS for scaled dot product attention.

    Parameters:
    - config: The experiment configuration
    - time_us: The execution time in microseconds
    - is_backward: Whether to calculate for backward pass (includes gradient computation)
    - sparsity: Sparsity factor between 0.0 and 1.0, where 0.0 means no sparsity and 1.0 means fully sparse

    Returns:
    - TFLOPS value
    """
    batch = config.batch_size
    heads = config.num_heads
    q_len = config.q_seq_len
    kv_len = config.kv_seq_len
    head_dim = config.head_dim

    # Per-(batch, head) forward-pass FLOP counts; matmuls count 2 FLOPs
    # per multiply-add.
    qk_flops = q_len * kv_len * head_dim * 2  # Q @ K^T: (M,D) x (D,N)
    softmax_flops = q_len * kv_len * 2  # softmax (exp and div)
    av_flops = q_len * kv_len * head_dim * 2  # attn @ V: (M,N) x (N,D)

    total_flops = batch * heads * (qk_flops + softmax_flops + av_flops)

    # Scale by density (1 - sparsity) to account for skipped entries.
    total_flops *= 1.0 - sparsity

    # For backward pass flash uses 2.5x more flops.
    if is_backward:
        total_flops *= 2.5

    # flops / seconds, expressed in units of 1e12.
    return total_flops / (time_us * 1e-6) / 1e12
|
|
|
|
|
|
def get_input(
    config: ExperimentConfig,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Create random (q, k, v) tensors for the configured attention shapes.

    All three tensors have ``requires_grad=True`` so the backward pass can
    be benchmarked; q uses q_seq_len while k and v use kv_seq_len.
    """

    def _rand(seq_len: int) -> torch.Tensor:
        # Shared factory: (batch, heads, seq, head_dim) on the config's
        # device and dtype.
        return torch.randn(
            (config.batch_size, config.num_heads, seq_len, config.head_dim),
            dtype=config.dtype,
            device=config.device,
            requires_grad=True,
        )

    # Evaluated left to right, preserving the original q, k, v RNG order.
    return _rand(config.q_seq_len), _rand(config.kv_seq_len), _rand(config.kv_seq_len)
|
|
|
|
|
|
def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
    """Benchmark forward and backward SDPA for one configuration.

    Restricts SDPA to ``config.backend`` when one is given; otherwise all
    backends remain enabled. Returns timings (microseconds) and the
    corresponding achieved TFLOPS.
    """
    query, key, value = get_input(config)
    causal = config.is_causal

    # Pin the SDPA backend only when the config requests a specific one.
    if config.backend is not None:
        backend_ctx = sdpa_kernel(config.backend)
    else:
        backend_ctx = nullcontext()

    with backend_ctx:
        forward_time = benchmark_cuda_function_in_microseconds(
            scaled_dot_product_attention,
            query,
            key,
            value,
            is_causal=causal,
            attn_mask=None,
        )
        out_torch = scaled_dot_product_attention(
            query, key, value, is_causal=causal, attn_mask=None
        )
        grad_out = torch.randn_like(out_torch)
        # retain_graph so the repeated backward calls during timing succeed.
        backward_time = benchmark_cuda_function_in_microseconds(
            out_torch.backward, grad_out, retain_graph=True
        )

    # Causal masking skips roughly half the attention matrix.
    sparsity = 0.5 if causal else 0.0
    forward_tflops = calculate_tflops(config, forward_time, sparsity=sparsity)
    backward_tflops = calculate_tflops(
        config, backward_time, is_backward=True, sparsity=sparsity
    )

    return ExperimentResults(
        forward_time=forward_time,
        backward_time=backward_time,
        forward_tflops=forward_tflops,
        backward_tflops=backward_tflops,
    )
|
|
|
|
|
|
def print_results(experiments: list[Experiment]):
    """Pretty-print experiments as a table, one column per result field.

    The always-identical ``device`` column is dropped, as is ``backend``
    when no specific backend was requested. Prints nothing for an empty
    list.
    """
    # Guard: with no experiments, the del/indexing below would raise
    # (KeyError on "device", IndexError on table_data["backend"][0]).
    if not experiments:
        return
    table_data = defaultdict(list)
    for experiment in experiments:
        for key, value in experiment.asdict().items():
            table_data[key].append(value)
    del table_data["device"]
    if table_data["backend"][0] is None:
        del table_data["backend"]
    print(tabulate(table_data, headers="keys", tablefmt="pretty", floatfmt=".3f"))
|
|
|
|
|
|
def write_results_to_csv(
    experiments: list[Experiment], output_dir: str = "benchmark_results"
):
    """
    Write experiment results to a CSV file in the specified directory.

    The filename includes a timestamp for uniqueness. The constant
    ``device`` column is omitted. Nothing is written (and no directory is
    created) when ``experiments`` is empty.
    """
    import csv
    import os
    from datetime import datetime

    # Bail out before touching the filesystem so an empty run leaves no
    # empty output directory behind.
    if not experiments:
        return

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Generate filename with timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = os.path.join(output_dir, f"benchmark_results_{timestamp}.csv")

    # Get all fields from the first experiment
    fieldnames = list(experiments[0].asdict().keys())
    if "device" in fieldnames:
        fieldnames.remove("device")  # Remove device field as it's always cuda

    # Write results to CSV
    with open(filename, "w", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for experiment in experiments:
            row = experiment.asdict()
            if "device" in row:
                del row["device"]  # Remove device field
            writer.writerow(row)

    # Report the actual path (the previous f-string had no placeholder and
    # never showed where the CSV went).
    print(f"Results written to: {filename}")
|
|
|
|
|
|
def generate_experiment_configs() -> list[ExperimentConfig]:
    """Build the full cross-product of all benchmark sweep parameters."""
    batch_sizes = [1, 8, 16]
    num_heads = [16]
    q_kv_seq_lens = [(128, 128), (256, 256), (512, 512), (1024, 1024), (8192, 8192)]
    embed_dims = [2048]
    backends = [None]  # If set to None, all backends are enabled
    dtypes = [
        torch.bfloat16,
    ]
    is_causal = [True, False]

    # One config per combination; product order matches the unpacking order.
    return [
        ExperimentConfig(
            batch_size=bsz,
            num_heads=heads,
            q_seq_len=q_seq_len,
            kv_seq_len=kv_seq_len,
            embed_dim=embed_dim,
            is_causal=causal,
            dtype=dtype,
            backend=backend,
        )
        for (
            bsz,
            heads,
            (q_seq_len, kv_seq_len),
            embed_dim,
            causal,
            dtype,
            backend,
        ) in itertools.product(
            batch_sizes,
            num_heads,
            q_kv_seq_lens,
            embed_dims,
            is_causal,
            dtypes,
            backends,
        )
    ]
|
|
|
|
|
|
def main():
    """Run every generated config and report results to stdout and CSV."""
    torch.manual_seed(123)  # deterministic random inputs across runs

    experiments = [
        Experiment(cfg, run_single_experiment(cfg))
        for cfg in tqdm(generate_experiment_configs())
    ]

    print_results(experiments)
    write_results_to_csv(experiments, "../benchmark_results")
|
|
|
|
|
|
# Allow running this benchmark directly as a script.
if __name__ == "__main__":
    main()
|