Compare commits

...

19 Commits

SHA1 Message Date
2882e9037b Fix output-json 2025-10-28 13:53:58 -07:00
a58b89df44 Fix linting 2025-10-28 12:54:07 -07:00
c684aa411b Merge remote-tracking branch 'origin/main' into attention_benchmark 2025-10-28 10:20:14 -07:00
fcb1bf53cb Merge remote-tracking branch 'origin/main' into attention_benchmark 2025-10-27 21:14:03 -07:00
df9bd5ce7c Fix linting 2025-10-27 21:09:17 -07:00
82b70f9a7f Update safe-backend use 2025-10-27 08:59:52 -07:00
2bcd29cc48 Handle nan 2025-10-26 21:36:18 -07:00
f7f410a1f5 Clean fav2 2025-10-26 21:06:24 -07:00
1832d9720f Format config logic now 2025-10-23 14:16:19 -07:00
4069d76684 Add print-config 2025-10-23 13:47:46 -07:00
11536e5a6b Update json 2025-10-23 10:21:01 -07:00
9d8778b8bf Lint and minor fixes 2025-10-20 08:58:32 -07:00
86279c6f25 Merge remote-tracking branch 'origin/main' into attention_benchmark 2025-10-19 16:08:23 -07:00
93553121d8 Update score_mod 2025-10-09 15:30:13 -07:00
e5eb96af95 Add json 2025-10-06 09:52:59 -07:00
4407b6c9e3 Transformer benchmarks 2025-10-06 08:30:52 -07:00
22ea056fcd Update score_mod 2025-09-30 14:00:51 -07:00
d7466fd5c6 Add config files 2025-09-30 10:36:04 -07:00
de61804393 Add attention benchmarking 2025-09-29 09:35:01 -07:00
3 changed files with 665 additions and 61 deletions

View File: config_utils.py

@@ -0,0 +1,157 @@
"""Configuration utilities for parsing JSON and YAML config files."""
import json
import re
def heads_input_type(s: str) -> tuple[int, int]:
"""Convert string format 'Hq,Hkv' to tuple (Hq, Hkv)."""
try:
hq, hkv = map(int, s.split(","))
return hq, hkv
except Exception as e:
raise ValueError("Heads must be Hq,Hkv") from e
default_config = {
"dynamic": False,
"calculate_bwd": False,
"dtype": "bfloat16",
"b": [2, 8, 16],
"nh": ["16,16", "16,2"],
"s": [512, 1024, 4096],
"d": [64, 128],
"mods": ["noop", "causal", "alibi", "sliding_window"],
"backend": ["efficient"],
"max_autotune": False,
"decoding": False,
"kv_size": None,
"throughput": True,
"save_path": None,
"output_json_for_dashboard": None,
"benchmark_name": "PyTorch operator microbenchmark",
}
def load_config_file(config_path: str) -> dict:
"""Load configuration from JSON or YAML file.
Automatically converts 'nh' field from strings to tuples.
Args:
config_path: Path to the configuration file
Returns:
Dictionary containing the configuration
Raises:
FileNotFoundError: If config file doesn't exist
ValueError: If config file format is invalid
"""
with open(config_path) as f:
config_str = f.read()
# Try to load as JSON first
try:
config = json.loads(config_str)
except json.JSONDecodeError:
# Fall back to YAML parsing
config = _parse_simple_yaml(config_str)
# Apply automatic conversions for 'nh' field
if "nh" in config and isinstance(config["nh"], list):
config["nh"] = [
heads_input_type(h) if isinstance(h, str) else h for h in config["nh"]
]
return config
def _parse_simple_yaml(yaml_str: str) -> dict:
"""Simple YAML parser for basic configs (without external dependencies).
Supports:
- key: value pairs
- booleans (true/false)
- null values
- integers and floats
- strings (quoted and unquoted)
- lists in JSON format [item1, item2, ...]
- comments (lines starting with # or after #)
Args:
yaml_str: YAML content as string
Returns:
Dictionary containing parsed YAML content
"""
config = {}
for line in yaml_str.split("\n"):
# Remove comments
line = line.split("#")[0].strip()
if not line or ":" not in line:
continue
key, value = line.split(":", 1)
key = key.strip()
value = value.strip()
# Parse value based on type
if value.lower() == "true":
config[key] = True
elif value.lower() == "false":
config[key] = False
elif value.lower() in ("null", "none", ""):
config[key] = None
elif value.startswith("[") and value.endswith("]"):
# Parse list - handle quoted strings properly
pattern = r'"([^"]+)"|\'([^\']+)\'|([^,\[\]\s]+)'
matches = re.findall(pattern, value[1:-1]) # Remove [ ]
parsed_items = []
for match in matches:
# match is a tuple of (double_quoted, single_quoted, unquoted)
item = match[0] or match[1] or match[2]
item = item.strip()
if item:
try:
parsed_items.append(int(item))
except ValueError:
parsed_items.append(item)
config[key] = parsed_items
elif value.startswith(('"', "'")):
config[key] = value.strip("\"'")
else:
# Try to parse as number
try:
config[key] = int(value)
except ValueError:
try:
config[key] = float(value)
except ValueError:
config[key] = value
return config
def print_default_config(output_format: str) -> None:
"""Print a default configuration template in JSON or YAML format.
Args:
output_format: Either "json" or "yaml"
"""
if output_format == "json":
print(json.dumps(default_config, indent=2))
else: # yaml
for key, value in default_config.items():
if value is None:
print(f"{key}: null")
elif isinstance(value, bool):
print(f"{key}: {str(value).lower()}")
elif isinstance(value, str):
print(f'{key}: "{value}"')
elif isinstance(value, list):
print(f"{key}: {json.dumps(value)}")
else:
print(f"{key}: {value}")

View File: config_basic.yaml

@@ -0,0 +1,29 @@
# Basic benchmark configuration for PyTorch transformer benchmarks
# Usage: python score_mod.py --config config_basic.yaml
# Core parameters
dynamic: false
calculate_bwd: true
dtype: "bfloat16"
# Shape parameters - larger sweep
b: [1, 2, 4, 8, 16] # batch sizes
nh: ["16,16", "16,2", "32,32", "32,4"] # [query_heads,key_value_heads]
s: [512, 1024, 2048, 4096, 8192] # sequence lengths
d: [64, 128] # head dimensions (limited to 128 for Flash Attention/cuDNN compatibility)
# All attention types
mods: ["noop", "causal", "rel", "head_bias", "alibi", "sliding_window", "prefix_lm", "softcap"]
# Multiple backends for comparison (SDPA + Flash Attention) - flex is always included internally
backend: ["efficient", "math", "cudnn", "fav2"]
max_autotune: true # Enable torch.compile with max-autotune for optimal performance
# Decoding and cache settings
decoding: false
kv_size: null
# Metrics and output
throughput: true # Calculate memory bandwidth & TFLOPS
save_path: "comprehensive_results.csv" # Save to CSV
output_json_for_dashboard: "attn_bench_basic.json"
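For comparison, a roughly equivalent pure-CLI sweep for a subset of these settings, using only flags that appear in score_mod.py's usage examples and argument parser below (head counts, head dims, backward timing, and the CSV save path are omitted because their exact flag spellings are not shown in this diff):

python score_mod.py \
    -dtype bfloat16 \
    -b 1 2 4 8 16 \
    -s 512 1024 2048 4096 8192 \
    -mods noop causal alibi sliding_window \
    --backend efficient math cudnn fav2 \
    --max-autotune \
    --output-json-for-dashboard attn_bench_basic.json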

View File: score_mod.py

@@ -1,15 +1,19 @@
import argparse
import csv
import gc
import itertools
import json
import random
import sys
from collections import defaultdict
from collections.abc import Callable
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from functools import partial
from typing import Optional, Union
from functools import partial, wraps
from typing import Literal, Optional, Union
import numpy as np
from config_utils import heads_input_type, load_config_file, print_default_config
from tabulate import tabulate
from tqdm import tqdm
@@ -33,6 +37,96 @@ torch._dynamo.config.recompile_limit = 1000
from torch._inductor.runtime.benchmarking import benchmarker
def cleanup_memory():
"""Aggressively free GPU memory"""
torch.cuda.empty_cache()
gc.collect()
if torch.cuda.is_available():
torch.cuda.synchronize()
def safe_backend(backend_name=None, return_dict=False):
"""Decorator that wraps backend functions with error handling
Args:
backend_name: Name of the backend for error messages
return_dict: If True, returns dict of results for all backends (for run_single_experiment)
If False, returns single ExperimentResults (for individual backend functions)
"""
def decorator(func):
@wraps(func)
def wrapper(config, *args, **kwargs):
try:
return func(config, *args, **kwargs)
except torch.OutOfMemoryError:
print(
f"[SKIP] OOM for {backend_name or func.__name__} with shape {config.shape}"
)
cleanup_memory()
except RuntimeError as e:
error_msg = str(e)
if "out of resource" in error_msg or "OutOfMemoryError" in error_msg:
print(
f"[SKIP] Triton OOM for {backend_name or func.__name__} with shape {config.shape}"
)
cleanup_memory()
elif "No valid triton configs" in error_msg:
print(
f"[SKIP] No valid Triton config for {backend_name or func.__name__} with shape {config.shape}"
)
else:
print(
f"[SKIP] Runtime error for {backend_name or func.__name__} with shape {config.shape}: {str(e)[:100]}"
)
except Exception as e:
print(
f"[SKIP] Error for {backend_name or func.__name__} with shape {config.shape}: {str(e)[:100]}"
)
# Return appropriate NaN result based on function type
if return_dict:
# For run_single_experiment: return dict with NaN for all backends
nan_result = ExperimentResults(
fwd_time=float("nan"),
bwd_time=float("nan") if config.calculate_bwd_time else None,
)
results = dict.fromkeys(config.backends, nan_result)
results["flex"] = ExperimentResults(
fwd_time=float("nan"),
bwd_time=float("nan") if config.calculate_bwd_time else None,
sparsity=None,
)
return results
else:
# For individual backend functions: return single ExperimentResults
return ExperimentResults(
fwd_time=float("nan"),
bwd_time=float("nan") if config.calculate_bwd_time else None,
)
return wrapper
return decorator
# Type definitions
Backend = Literal["math", "efficient", "cudnn", "fav2", "fav3", "fakv", "og-eager"]
AttentionType = Literal[
"noop",
"causal",
"rel",
"head_bias",
"alibi",
"sliding_window",
"document_mask",
"prefix_lm",
"softcap",
]
DtypeString = Literal["bfloat16", "float16", "float32"]
SpeedupType = Literal["fwd", "bwd"]
def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
# warmup
for _ in range(5):
@@ -48,6 +142,7 @@ class ExperimentConfig:
calculate_bwd_time: bool
cal_bandwidth: bool
backends: list[str]
max_autotune: bool
def __post_init__(self):
assert len(self.shape) == 6, (
@@ -62,6 +157,7 @@ class ExperimentConfig:
d.pop("cal_bandwidth", None)
d["shape(B,Hq,M,Hkv,N,D)"] = d.pop("shape")
d.pop("backends", None)
d.pop("max_autotune", False)
return d
@@ -209,6 +305,7 @@ def query_key_value_clones(
return query_ref, key_ref, value_ref
@safe_backend("SDPA")
def run_single_backend_sdpa(
config: ExperimentConfig,
query: torch.Tensor,
@@ -223,6 +320,7 @@ def run_single_backend_sdpa(
backend_context = get_backend_context(backend)
with backend_context:
_device = torch.device("cuda")
eager_sdpa = generate_eager_sdpa(
config.attn_type, config.shape, config.dtype, block_mask, score_mod
)
@@ -290,6 +388,7 @@ def run_single_backend_sdpa(
)
@safe_backend("FlashAttention")
def run_single_backend_FA(
config: ExperimentConfig,
query: torch.Tensor,
@@ -301,9 +400,9 @@ def run_single_backend_FA(
mask_kwargs,
backend: str,
) -> ExperimentResults:
assert backend in ["fav2", "fav3", "fakv"]
assert backend in ["fav3", "fakv"]
# Generate callable for specific backend.
if backend in ["fav2", "fav3"]:
if backend in ["fav3"]:
FA = generate_FA_callable(
config.attn_type, config.shape, config.dtype, backend, **mask_kwargs
)
@@ -354,10 +453,10 @@ def run_single_backend_FA(
)
@safe_backend("flex_attention", return_dict=True)
def run_single_experiment(
config: ExperimentConfig,
dynamic=False,
max_autotune=False,
) -> dict[str, ExperimentResults]:
device = torch.device("cuda")
batch_size, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim = config.shape
@@ -377,7 +476,7 @@ def run_single_experiment(
block_mask, mask_kwargs = generate_block_mask(config.attn_type, config.shape)
kernel_options = get_kernel_options(config.attn_type, config.shape)
if max_autotune:
if config.max_autotune:
compiled_sdpa = torch.compile(
flex_attention, dynamic=dynamic, mode="max-autotune-no-cudagraphs"
)
@@ -407,7 +506,7 @@ def run_single_experiment(
results = {}
for backend in config.backends:
if backend in ["fav2", "fav3", "fakv"]:
if backend in ["fav3", "fakv"]:
results[backend] = run_single_backend_FA(
config,
query,
@@ -419,7 +518,7 @@ def run_single_experiment(
mask_kwargs,
backend,
)
else: # sdpa
else: # sdpa (also supports fav2)
results[backend] = run_single_backend_sdpa(
config,
query,
@@ -440,7 +539,7 @@ def run_single_experiment(
sparsity = block_mask.sparsity() / 100.0 if block_mask is not None else 0.0
sparsity = sparsity if config.attn_type != "document_mask" else 0.5
results["compiled"] = ExperimentResults(
results["flex"] = ExperimentResults(
fwd_time=forward_compiled_time,
bwd_time=backward_compile_time if config.calculate_bwd_time else None,
sparsity=sparsity,
@@ -501,15 +600,15 @@ def calculate_tflops(config: ExperimentConfig, results: ExperimentResults) -> float:
softmax_flops = M * N * 2 # Not counting online softmax overhead
o_flops = M * D * N * 2
# Not counting split k overhead
total_flops = B * Hq * (qk_flops + softmax_flops + o_flops) * (1 - results.sparsity)
sparsity = results.sparsity if results.sparsity is not None else 0.0
total_flops = B * Hq * (qk_flops + softmax_flops + o_flops) * (1 - sparsity)
return total_flops / results.fwd_time / 1e6 # in TFLOPs/
def get_average_speedups(results: list[Experiment], type: str, backend: str):
# Calculate speedups
speedups = [
calculate_speedup(r.results["compiled"], r.results[backend], type)
for r in results
calculate_speedup(r.results["flex"], r.results[backend], type) for r in results
]
# Find indices of max and min speedups
@@ -537,7 +636,7 @@ def get_average_speedups(results: list[Experiment], type: str, backend: str):
def print_results(results: list[Experiment], save_path: Optional[str] = None):
table_data = defaultdict(list)
for experiment in results:
backends = experiment.config.backends + ["compiled"]
backends = experiment.config.backends + ["flex"]
for key, value in experiment.asdict().items():
if key in backends:
if value.fwd_time:
@@ -550,45 +649,43 @@ def print_results(results: list[Experiment], save_path: Optional[str] = None):
# Calculate speedups
for backend in results[0].config.backends:
fwd_speedups = [
calculate_speedup(r.results["compiled"], r.results[backend], type="fwd")
calculate_speedup(r.results["flex"], r.results[backend], type="fwd")
for r in results
]
table_data[f"fwd_{backend}_speedup"] = fwd_speedups
table_data[f"fwd_speedup_flex_over_{backend}"] = fwd_speedups
if results[0].config.calculate_bwd_time:
for backend in results[0].config.backends:
bwd_speedups = [
calculate_speedup(r.results["compiled"], r.results[backend], type="bwd")
calculate_speedup(r.results["flex"], r.results[backend], type="bwd")
for r in results
]
table_data[f"bwd_{backend}_speedup"] = bwd_speedups
table_data[f"bwd_speedup_flex_over_{backend}"] = bwd_speedups
# Calculate mem + computational throughput
if results[0].config.cal_bandwidth:
fwd_bandwidth = [
calculate_bandwidth(r.config, r.results["compiled"], type="fwd")
calculate_bandwidth(r.config, r.results["flex"], type="fwd")
for r in results
]
table_data["fwd_mem_bw (TB/s)"] = fwd_bandwidth
fwd_tflops = [
calculate_tflops(r.config, r.results["compiled"]) for r in results
]
fwd_tflops = [calculate_tflops(r.config, r.results["flex"]) for r in results]
table_data["TFlops/s"] = fwd_tflops
print(tabulate(table_data, headers="keys", tablefmt="github", floatfmt=".3f"))
for backend in results[0].config.backends:
if np.isnan(table_data[f"fwd_{backend}_speedup"]).all():
if np.isnan(table_data[f"fwd_speedup_flex_over_{backend}"]).all():
continue
print("\n")
print(f"FWD Speedups vs. {backend}".center(125, "="))
print(f"FWD Speedup of Flex over {backend}".center(125, "="))
print("\n")
average_data = get_average_speedups(results, type="fwd", backend=backend)
print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))
if results[0].config.calculate_bwd_time:
print("\n")
print(f"BWD Speedups vs. {backend}".center(125, "="))
print(f"BWD Speedup of Flex over {backend}".center(125, "="))
print("\n")
average_data = get_average_speedups(results, type="bwd", backend=backend)
print(
@@ -791,14 +888,14 @@ def get_backend_context(backend: str):
Returns a context manager for the specified backend.
Args:
backend (str): The name of the backend to use.
Valid options are 'fav2', 'cudnn', 'math', 'efficient', 'fav3', 'fakv', 'og-eager'.
Valid options are 'math', 'efficient', 'cudnn', 'fav2', 'fav3', 'fakv', 'og-eager'.
Returns:
A context manager for the specified backend.
Raises:
ValueError: If an invalid backend is specified.
"""
backends = {
"fav2": nullcontext(),
"fav2": sdpa_kernel(SDPBackend.FLASH_ATTENTION),
"cudnn": sdpa_kernel(SDPBackend.CUDNN_ATTENTION),
"math": sdpa_kernel(SDPBackend.MATH),
"efficient": sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION),
@@ -820,15 +917,7 @@ generate_FA_callable(
) -> Callable | None:
if dtype not in [torch.float16, torch.bfloat16]:
return None
if backend == "fav2":
try:
from flash_attn import flash_attn_func, flash_attn_varlen_func
except ImportError:
print(
"Flash attention 2 is not installed. Please install it to run fav2 backend. "
)
raise
elif backend == "fav3":
if backend == "fav3":
try:
from flash_attn.flash_attn_interface import (
flash_attn_func,
@@ -1034,6 +1123,7 @@ def generate_experiment_configs(
kv_cache_size: list[int],
cal_bandwidth: bool,
backends: list[str],
max_autotune: bool,
) -> list[ExperimentConfig]:
assert not (calculate_bwd and decoding), "Decoding does not support backward"
@@ -1077,52 +1167,333 @@ def generate_experiment_configs(
calculate_bwd_time=calculate_bwd,
cal_bandwidth=cal_bandwidth,
backends=backends,
max_autotune=max_autotune,
)
)
return all_configs
def main(args):
def _output_json_for_dashboard(
experiments,
output_file,
benchmark_name="PyTorch operator microbenchmark",
):
"""
Write the result into JSON format for PyTorch OSS dashboard.
The JSON format is defined at
https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
Args:
experiments: List of experiment results
output_file: Path to output JSON file
benchmark_name: Name of the benchmark
"""
if not experiments:
return
import math
import platform
from dataclasses import asdict, dataclass
from typing import Any, Optional
# Prepare headers and records for JSON output
records = []
for experiment in experiments:
config = experiment.config
results_dict = (
experiment.results
) # This is a dict: backend -> ExperimentResults
# Process each backend result
for backend, results in results_dict.items():
# Skip backends that were not run (NaN results)
if math.isnan(results.fwd_time):
continue
# Extract data from experiment
test_name = f"{backend}_{config.attn_type}_"
input_config = f"shape: {config.shape}, dtype: {config.dtype}"
# Determine mode based on backward pass
mode = "training" if config.calculate_bwd_time else "inference"
# Extract dtype
dtype = (
str(config.dtype).split(".")[1]
if "." in str(config.dtype)
else str(config.dtype)
)
# Determine device
device = "cuda"
# Get device architecture
device_arch = (
torch.cuda.get_device_name(0)
if device == "cuda"
else platform.processor()
if device == "cpu"
else "unknown"
)
# Create dataclasses for JSON structure
@dataclass
class BenchmarkInfo:
name: str
mode: Optional[str]
dtype: str
extra_info: dict[str, Any]
@dataclass
class ModelInfo:
name: str
type: str
origins: list[str]
extra_info: dict[str, Any]
@dataclass
class MetricInfo:
name: str
unit: str
benchmark_values: list[float]
target_value: Optional[float]
@dataclass
class BenchmarkRecord:
benchmark: BenchmarkInfo
model: ModelInfo
metric: MetricInfo
# Benchmark extra info
benchmark_extra_info = {
"input_config": input_config,
"device": device,
"arch": device_arch,
"operator_name": backend,
"attn_type": config.attn_type,
"shape": str(config.shape),
"max_autotune": config.max_autotune,
}
# Add record for forward latency
record_fwd_latency = BenchmarkRecord(
benchmark=BenchmarkInfo(
name=benchmark_name,
mode=mode,
dtype=dtype,
extra_info=benchmark_extra_info,
),
model=ModelInfo(
name=test_name + str(config.shape),
type="attention-benchmark",
origins=["pytorch"],
extra_info={
"operator_name": backend,
"attn_type": config.attn_type,
},
),
metric=MetricInfo(
name="forward latency",
unit="us",
benchmark_values=[results.fwd_time],
target_value=None,
),
)
records.append(asdict(record_fwd_latency))
# Add record for forward memory bandwidth (if available)
if config.cal_bandwidth:
record_fwd_bandwidth = BenchmarkRecord(
benchmark=BenchmarkInfo(
name=benchmark_name,
mode=mode,
dtype=dtype,
extra_info=benchmark_extra_info,
),
model=ModelInfo(
name=test_name + str(config.shape),
type="attention-benchmark",
origins=["pytorch"],
extra_info={
"operator_name": backend,
},
),
metric=MetricInfo(
name="memory bandwidth",
unit="TB/s",
benchmark_values=[calculate_bandwidth(config, results, "fwd")],
target_value=None,
),
)
records.append(asdict(record_fwd_bandwidth))
# Add record for forward TFLOPS (if available)
if config.cal_bandwidth:
record_fwd_tflops = BenchmarkRecord(
benchmark=BenchmarkInfo(
name=benchmark_name,
mode=mode,
dtype=dtype,
extra_info=benchmark_extra_info,
),
model=ModelInfo(
name=test_name + str(config.shape),
type="attention-benchmark",
origins=["pytorch"],
extra_info={
"operator_name": backend,
},
),
metric=MetricInfo(
name="tflops",
unit="TFLOPS/s",
benchmark_values=[calculate_tflops(config, results)],
target_value=None,
),
)
records.append(asdict(record_fwd_tflops))
# Add record for backward latency (if available and not NaN)
if (
config.calculate_bwd_time
and results.bwd_time is not None
and not math.isnan(results.bwd_time)
):
record_bwd_latency = BenchmarkRecord(
benchmark=BenchmarkInfo(
name=benchmark_name,
mode=mode,
dtype=dtype,
extra_info=benchmark_extra_info,
),
model=ModelInfo(
name=test_name + str(config.shape),
type="attention-benchmark",
origins=["pytorch"],
extra_info={
"operator_name": backend,
},
),
metric=MetricInfo(
name="backward latency",
unit="us",
benchmark_values=[results.bwd_time],
target_value=None,
),
)
records.append(asdict(record_bwd_latency))
# Write all records to the output file
with open(output_file, "w", encoding="utf-8") as f:
json.dump(records, f, indent=2)
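Concretely, the output file is a JSON list with one record per backend/metric combination. A single forward-latency record produced by the dataclasses above would look roughly like this; every value is illustrative, including the shape, GPU name, and timing:

[
  {
    "benchmark": {
      "name": "PyTorch operator microbenchmark",
      "mode": "inference",
      "dtype": "bfloat16",
      "extra_info": {
        "input_config": "shape: (2, 16, 1024, 16, 1024, 64), dtype: torch.bfloat16",
        "device": "cuda",
        "arch": "<GPU name>",
        "operator_name": "efficient",
        "attn_type": "causal",
        "shape": "(2, 16, 1024, 16, 1024, 64)",
        "max_autotune": false
      }
    },
    "model": {
      "name": "efficient_causal_(2, 16, 1024, 16, 1024, 64)",
      "type": "attention-benchmark",
      "origins": ["pytorch"],
      "extra_info": {"operator_name": "efficient", "attn_type": "causal"}
    },
    "metric": {
      "name": "forward latency",
      "unit": "us",
      "benchmark_values": [123.4],
      "target_value": null
    }
  }
]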
def main(
dynamic: bool = False,
calculate_bwd: bool = False,
dtype: DtypeString = "bfloat16",
b: list[int] | None = None,
nh: list[str] | None = None,
s: list[int] | None = None,
d: list[int] | None = None,
mods: list[AttentionType] | None = None,
backend: list[Backend] | None = None,
max_autotune: bool = False,
decoding: bool = False,
kv_size: Optional[list[int]] = None,
throughput: bool = True,
save_path: Optional[str] = None,
output_json_for_dashboard: Optional[str] = None,
benchmark_name: str = "PyTorch operator microbenchmark",
) -> None:
"""Run sweep over sizes and score mods for flex attention.
Usage Examples:
# Use a yml config file
python score_mod.py --config basic_config.yaml
# Use a json config file
python score_mod.py --config my_config.json
# Generate a config template
python score_mod.py --print-config json > my_config.json # For a json config
python score_mod.py --print-config yaml > my_config.yaml # For a yaml config
# Override config with CLI args
python score_mod.py --config my_config.json -dtype float16 --max-autotune
# Pure CLI usage
python score_mod.py -b 4 8 -s 1024 2048 -mods causal alibi --backend efficient
Args:
dynamic: Runs a dynamic shapes version of compiled flex attention
calculate_bwd: Calculate backward pass times
dtype: Data type for tensors (bfloat16, float16, float32)
b: Batch sizes to benchmark
nh: Number of query and key/value heads in format "Hq,Hkv"
s: Sequence lengths to benchmark
d: Head dimensions to benchmark
mods: Score modifications: noop, causal, rel, head_bias, alibi, sliding_window, document_mask, prefix_lm, softcap
backend: Backends for attention computation: math, efficient, cudnn, fav2, fav3, fakv, og-eager
max_autotune: Turn on max-autotune optimization
decoding: Benchmark decoding mode (query sequence length = 1)
kv_size: Key/value cache size in MiB (ignores batch size if specified)
throughput: Calculate kernel memory bandwidth & computational throughput (always True)
save_path: Path to save the results CSV file
output_json_for_dashboard: Path to save results in JSON format for PyTorch OSS dashboard
benchmark_name: Name of the benchmark for dashboard output
"""
# Convert dtype string to torch dtype (if not already converted)
import torch
if isinstance(dtype, str):
dtype = getattr(torch, dtype)
# Always calculate throughput
throughput = True
print("Backend: ", backend)
seed = 123
np.random.seed(seed)
torch.manual_seed(seed)
results = []
for config in tqdm(
generate_experiment_configs(
args.calculate_bwd,
args.dtype,
args.b,
args.nh,
args.s,
args.d,
args.mods,
args.decoding,
args.kv_size,
args.throughput,
args.backend,
)
for experiment_count, config in enumerate(
tqdm(
generate_experiment_configs(
calculate_bwd,
dtype,
b,
nh,
s,
d,
mods,
decoding,
kv_size,
throughput,
backend,
max_autotune,
)
),
start=1,
):
results.append(
Experiment(
config,
run_single_experiment(
config,
dynamic=args.dynamic,
max_autotune=args.max_autotune,
dynamic=dynamic,
),
)
)
print_results(results, args.save_path)
# Periodic memory cleanup every 50 experiments
if experiment_count % 50 == 0:
cleanup_memory()
print_results(results, save_path)
def heads_input_type(s):
try:
hq, hkv = map(int, s.split(","))
return hq, hkv
except Exception as e:
raise argparse.ArgumentTypeError("Heads must be Hq,Hkv") from e
# Output JSON for dashboard if requested
if output_json_for_dashboard:
_output_json_for_dashboard(results, output_json_for_dashboard, benchmark_name)
if __name__ == "__main__":
@@ -1130,6 +1501,12 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Run sweep over sizes and score mods for flex attention"
)
parser.add_argument(
"--config",
type=str,
help="Path to JSON config file. CLI args override config file values.",
default=None,
)
parser.add_argument(
"--dynamic",
action="store_true",
@@ -1199,8 +1576,49 @@ Ignores -b batch size and calculate batch size from kv size instead when specified
default=["efficient"],
help="Backend to use for attention computation",
)
parser.add_argument(
"--output-json-for-dashboard",
type=str,
help="Path to save results in JSON format for PyTorch OSS dashboard",
default=None,
)
parser.add_argument(
"--benchmark-name",
type=str,
help="Name of the benchmark for dashboard output",
default="PyTorch operator microbenchmark",
)
parser.add_argument(
"--print-config",
type=str,
choices=["json", "yaml"],
help="Print a default config template in JSON or YAML format and exit",
default=None,
)
# Parse arguments
args = parser.parse_args()
args.dtype = getattr(torch, args.dtype)
main(args)
# Handle --print-config
if args.print_config:
print_default_config(args.print_config)
sys.exit(0)
# Load and merge config if provided
if args.config:
config = load_config_file(args.config)
# Merge config with CLI args (CLI args take precedence)
json_args = argparse.Namespace()
json_args.__dict__ = config
args = parser.parse_args(namespace=json_args)
# Convert dtype string to torch dtype (only if it's still a string)
if isinstance(args.dtype, str):
args.dtype = getattr(torch, args.dtype)
# Remove config and print_config from args before passing to main
args_dict = vars(args)
args_dict.pop("config", None)
args_dict.pop("print_config", None)
main(**args_dict)
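One note on the config merge in the __main__ block above: parse_args(namespace=json_args) only assigns parser defaults for attributes the namespace does not already have, which is what gives CLI flags precedence over config-file values. A small self-contained sketch of that argparse behavior (hypothetical flags and values, not the real parser):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-b", type=int, nargs="+", default=[2, 8, 16])
parser.add_argument("-dtype", type=str, default="bfloat16")

# Values loaded from a config file land on the namespace first ...
ns = argparse.Namespace()
ns.__dict__ = {"b": [4], "dtype": "float16"}

# ... and parse_args only overwrites the entries actually passed on the CLI.
merged = parser.parse_args(["-b", "1", "2"], namespace=ns)
print(merged.b, merged.dtype)  # -> [1, 2] float16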