pytorch/benchmarks/transformer/score_mod.py
Commit a9b29caeae by jainapurva: Add attention benchmarking numbers to pytorch operator microbenchmarks (#164155)
This pull request introduces a standardized YAML-based configuration system for the transformer attention benchmarks, making it easier to run and manage comprehensive performance tests. It adds example configs and a wrapper script that converts YAML configs into CLI arguments for the benchmark runner.
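
For illustration, a minimal YAML config might look like the sketch below. This is an assumption-laden sketch, not the shipped `configs/config_test.yaml`: the key names are assumed to mirror the CLI argument names, since the loaded config dict is merged directly into the argparse namespace (see the `__main__` block at the bottom of the script).

```yaml
# Hypothetical sketch of a benchmark config; keys assumed to mirror the CLI argument names.
dtype: bfloat16
b: [2, 8]                  # batch sizes
s: [1024, 4096]            # sequence lengths
d: [64, 128]               # head dims
mods: [causal, alibi, sliding_window]
backend: [efficient, cudnn]
max_autotune: false
calculate_bwd: false
```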

#### Next Steps:
CI Enablement: This change lays the groundwork for running the attention ops in CI for regression tracking.

#### Developer flow (run locally):
`python score_mod.py --config configs/config_test.yaml`
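
Other invocation patterns, as documented in the script's own usage notes in `main()`:

```bash
# Generate a config template
python score_mod.py --print-config yaml > my_config.yaml

# Override config values from the CLI
python score_mod.py --config my_config.json -dtype float16 --max-autotune

# Pure CLI usage
python score_mod.py -b 4 8 -s 1024 2048 -mods causal alibi --backend efficient
```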

#### Enabling CI run: https://github.com/pytorch/pytorch/pull/165915

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164155
Approved by: https://github.com/jbschlosser
2025-10-28 23:46:04 +00:00


import argparse
import csv
import gc
import itertools
import json
import random
import sys
from collections import defaultdict
from collections.abc import Callable
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from functools import partial, wraps
from typing import Literal, Optional, Union
import numpy as np
from config_utils import heads_input_type, load_config_file, print_default_config
from tabulate import tabulate
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.attention.flex_attention import (
BlockMask,
create_block_mask,
create_mask,
flex_attention,
noop_mask,
)
torch._dynamo.config.automatic_dynamic_shapes = False
# Needed since changing args to function causes recompiles
torch._dynamo.config.recompile_limit = 1000
from torch._inductor.runtime.benchmarking import benchmarker
def cleanup_memory():
"""Aggressively free GPU memory"""
torch.cuda.empty_cache()
gc.collect()
if torch.cuda.is_available():
torch.cuda.synchronize()
def safe_backend(backend_name=None, return_dict=False):
"""Decorator that wraps backend functions with error handling
Args:
backend_name: Name of the backend for error messages
return_dict: If True, returns dict of results for all backends (for run_single_experiment)
If False, returns single ExperimentResults (for individual backend functions)
"""
def decorator(func):
@wraps(func)
def wrapper(config, *args, **kwargs):
try:
return func(config, *args, **kwargs)
except torch.OutOfMemoryError:
print(
f"[SKIP] OOM for {backend_name or func.__name__} with shape {config.shape}"
)
cleanup_memory()
except RuntimeError as e:
error_msg = str(e)
if "out of resource" in error_msg or "OutOfMemoryError" in error_msg:
print(
f"[SKIP] Triton OOM for {backend_name or func.__name__} with shape {config.shape}"
)
cleanup_memory()
elif "No valid triton configs" in error_msg:
print(
f"[SKIP] No valid Triton config for {backend_name or func.__name__} with shape {config.shape}"
)
else:
print(
f"[SKIP] Runtime error for {backend_name or func.__name__} with shape {config.shape}: {str(e)[:100]}"
)
except Exception as e:
print(
f"[SKIP] Error for {backend_name or func.__name__} with shape {config.shape}: {str(e)[:100]}"
)
# Return appropriate NaN result based on function type
if return_dict:
# For run_single_experiment: return dict with NaN for all backends
nan_result = ExperimentResults(
fwd_time=float("nan"),
bwd_time=float("nan") if config.calculate_bwd_time else None,
)
results = dict.fromkeys(config.backends, nan_result)
results["flex"] = ExperimentResults(
fwd_time=float("nan"),
bwd_time=float("nan") if config.calculate_bwd_time else None,
sparsity=None,
)
return results
else:
# For individual backend functions: return single ExperimentResults
return ExperimentResults(
fwd_time=float("nan"),
bwd_time=float("nan") if config.calculate_bwd_time else None,
)
return wrapper
return decorator
# Type definitions
Backend = Literal["math", "efficient", "cudnn", "fav2", "fav3", "fakv", "og-eager"]
AttentionType = Literal[
"noop",
"causal",
"rel",
"head_bias",
"alibi",
"sliding_window",
"document_mask",
"prefix_lm",
"softcap",
]
DtypeString = Literal["bfloat16", "float16", "float32"]
SpeedupType = Literal["fwd", "bwd"]
def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
# warmup
for _ in range(5):
func(*args, **kwargs)
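    # benchmark_gpu reports time in milliseconds; scale to microseconds.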
return benchmarker.benchmark_gpu(lambda: func(*args, **kwargs)) * 1e3
@dataclass(frozen=True)
class ExperimentConfig:
shape: tuple[int] # [B, Hq, M, Hkv, N, D]
attn_type: str
dtype: torch.dtype
calculate_bwd_time: bool
cal_bandwidth: bool
backends: list[str]
max_autotune: bool
def __post_init__(self):
assert len(self.shape) == 6, (
"Shape must be of length 6"
) # [B, Hq, M, Hkv, N, D]
def asdict(self):
# Convert the dataclass instance to a dictionary
d = asdict(self)
# Remove the 'calculate_bwd_time' and `cal_bandwidth` key
d.pop("calculate_bwd_time", None)
d.pop("cal_bandwidth", None)
d["shape(B,Hq,M,Hkv,N,D)"] = d.pop("shape")
d.pop("backends", None)
d.pop("max_autotune", False)
return d
@dataclass(frozen=True)
class Times:
eager_time: float
compiled_time: float
@dataclass(frozen=True)
class ExperimentResults:
fwd_time: float
bwd_time: Optional[float]
sparsity: Optional[float] = None
@dataclass(frozen=True)
class Experiment:
config: ExperimentConfig
results: dict[str, ExperimentResults] # backend -> ExperimentResults
def asdict(self):
dict1 = self.config.asdict()
dict2 = self.results
return {**dict1, **dict2}
def generate_inputs(
batch_size: int,
q_heads: int,
q_sequence_length: int,
kv_heads: int,
kv_sequence_length: int,
head_dim: int,
dtype: torch.dtype,
device: torch.device,
requires_grad: bool,
nested_tensors: bool = False,
):
torch.manual_seed(0)
q_shape = (batch_size, q_sequence_length, q_heads * head_dim)
kv_shape = (batch_size, kv_sequence_length, kv_heads * head_dim)
assert q_heads % kv_heads == 0
make_q = partial(
torch.rand, q_shape, device=device, dtype=dtype, requires_grad=requires_grad
)
make_kv = partial(
torch.rand, kv_shape, device=device, dtype=dtype, requires_grad=requires_grad
)
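    # nested_tensors packs all sequences into a single "batch" of length B * S (used for the document_mask jagged layout).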
if nested_tensors:
query = (
make_q()
.view(1, q_sequence_length * batch_size, q_heads, head_dim)
.transpose(1, 2)
)
key = (
make_kv()
.view(1, batch_size * kv_sequence_length, kv_heads, head_dim)
.transpose(1, 2)
)
value = (
make_kv()
.view(1, batch_size * kv_sequence_length, kv_heads, head_dim)
.transpose(1, 2)
)
else:
query = (
make_q()
.view(batch_size, q_sequence_length, q_heads, head_dim)
.transpose(1, 2)
)
key = (
make_kv()
.view(batch_size, kv_sequence_length, kv_heads, head_dim)
.transpose(1, 2)
)
value = (
make_kv()
.view(batch_size, kv_sequence_length, kv_heads, head_dim)
.transpose(1, 2)
)
return query, key, value
def generate_jagged_inputs(
shape: tuple[int],
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
offsets: torch.Tensor,
):
B, Hq, M, Hkv, N, D = shape
def offsets_to_lengths(
offsets: torch.Tensor, device: Union[str, torch.device]
    ) -> torch.Tensor:
"""Converts a list of offsets to a list of lengths. Reverse op of attn_gym.masks.document_mask.length_to_offsets
Args:
offsets: A 1D tensor of offsets
device: The device to place the output tensor on
"""
lengths = offsets[1:] - offsets[:-1]
return lengths
flatten_q = query.transpose(1, 2).flatten(start_dim=0, end_dim=1)
flatten_k = key.transpose(1, 2).flatten(start_dim=0, end_dim=1)
flatten_v = value.transpose(1, 2).flatten(start_dim=0, end_dim=1)
q_list = [
flatten_q[offsets[i] : offsets[i + 1]].clone().detach().to(query.dtype)
for i in range(len(offsets) - 1)
]
q = torch.nested.as_nested_tensor(q_list, device=query.device)
k_list = [
flatten_k[offsets[i] : offsets[i + 1]].clone().detach().to(key.dtype)
for i in range(len(offsets) - 1)
]
k = torch.nested.as_nested_tensor(k_list, device=key.device)
v_list = [
flatten_v[offsets[i] : offsets[i + 1]].clone().detach().to(value.dtype)
for i in range(len(offsets) - 1)
]
v = torch.nested.as_nested_tensor(v_list, device=value.device)
return q, k, v
def query_key_value_clones(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dtype: torch.dtype = None,
):
"""Clones the query, key, and value tensors and moves them to the specified dtype."""
if dtype is None:
dtype = query.dtype
query_ref = query.clone().detach().to(dtype).requires_grad_(query.requires_grad)
key_ref = key.clone().detach().to(dtype).requires_grad_(key.requires_grad)
value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad)
return query_ref, key_ref, value_ref
@safe_backend("SDPA")
def run_single_backend_sdpa(
config: ExperimentConfig,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out_compile: torch.Tensor,
score_mod: Callable | None,
block_mask: BlockMask | None,
mask_kwargs,
backend: str,
) -> ExperimentResults:
backend_context = get_backend_context(backend)
with backend_context:
_device = torch.device("cuda")
eager_sdpa = generate_eager_sdpa(
config.attn_type, config.shape, config.dtype, block_mask, score_mod
)
if config.attn_type == "document_mask":
q_eager, k_eager, v_eager = generate_jagged_inputs(
config.shape, query, key, value, **mask_kwargs
)
q_eager = q_eager.transpose(1, 2).requires_grad_(query.requires_grad)
k_eager = k_eager.transpose(1, 2).requires_grad_(key.requires_grad)
v_eager = v_eager.transpose(1, 2).requires_grad_(value.requires_grad)
else:
q_eager, k_eager, v_eager = query_key_value_clones(query, key, value)
if eager_sdpa:
try:
out_eager = eager_sdpa(query=q_eager, key=k_eager, value=v_eager)
except RuntimeError as e:
print(
f"[SKIP] SDPA Backend {backend} for shape {config.shape}. \n\t\t\tError encountered: {e} "
)
return ExperimentResults(
fwd_time=float("nan"),
bwd_time=float("nan") if config.calculate_bwd_time else None,
)
if config.attn_type in ["document_mask"]:
flatten_o_eager = torch.cat(torch.unbind(out_eager.transpose(1, 2)))
flatten_o_compile = out_compile.transpose(1, 2).flatten(
start_dim=0, end_dim=1
)
torch.testing.assert_close(
flatten_o_eager, flatten_o_compile, atol=1e-2, rtol=1e-2
)
elif not (
config.attn_type in ["rel", "alibi"]
and config.dtype in [torch.float16, torch.bfloat16]
): # rel has accuracy issue with 16bit floats
torch.testing.assert_close(out_eager, out_compile, atol=1e-2, rtol=1e-2)
if eager_sdpa:
forward_eager_time = benchmark_torch_function_in_microseconds(
eager_sdpa, query=q_eager, key=k_eager, value=v_eager
)
else:
forward_eager_time = float("nan")
if config.calculate_bwd_time:
# TODO: debug backward pass for njt
if eager_sdpa and config.attn_type != "document_mask":
d_out = torch.randn_like(out_eager.transpose(1, 2)).transpose(1, 2)
backward_eager_time = benchmark_torch_function_in_microseconds(
out_eager.backward, d_out, retain_graph=True
)
else:
backward_eager_time = float("nan")
return ExperimentResults(
fwd_time=forward_eager_time,
bwd_time=backward_eager_time,
)
else:
return ExperimentResults(
fwd_time=forward_eager_time,
bwd_time=None,
)
@safe_backend("FlashAttention")
def run_single_backend_FA(
config: ExperimentConfig,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out_compile: torch.Tensor,
score_mod: Callable | None,
block_mask: BlockMask | None,
mask_kwargs,
backend: str,
) -> ExperimentResults:
assert backend in ["fav3", "fakv"]
# Generate callable for specific backend.
if backend in ["fav3"]:
FA = generate_FA_callable(
config.attn_type, config.shape, config.dtype, backend, **mask_kwargs
)
elif backend == "fakv":
FA = generate_FD_callable(config.attn_type, config.shape, config.dtype)
q_FA, k_FA, v_FA = query_key_value_clones(query, key, value)
q_FA, k_FA, v_FA = q_FA.transpose(1, 2), k_FA.transpose(1, 2), v_FA.transpose(1, 2)
if config.attn_type == "document_mask":
q_FA = q_FA.flatten(start_dim=0, end_dim=1)
k_FA = k_FA.flatten(start_dim=0, end_dim=1)
v_FA = v_FA.flatten(start_dim=0, end_dim=1)
if FA:
out_FA = FA(q=q_FA, k=k_FA, v=v_FA)
if config.attn_type in ["document_mask"]:
out_FA_updated = out_FA[None, :, :, :]
else:
out_FA_updated = out_FA
if not (
config.attn_type in ["rel", "alibi"]
and config.dtype in [torch.float16, torch.bfloat16]
):
torch.testing.assert_close(
out_FA_updated, out_compile.transpose(1, 2), atol=1e-2, rtol=1e-2
)
if FA:
forward_FA_time = benchmark_torch_function_in_microseconds(
FA, q=q_FA, k=k_FA, v=v_FA
)
else:
forward_FA_time = float("nan")
if config.calculate_bwd_time:
if FA:
d_out = torch.randn_like(out_FA)
backward_FA_time = benchmark_torch_function_in_microseconds(
out_FA.backward, d_out, retain_graph=True
)
else:
backward_FA_time = float("nan")
return ExperimentResults(
fwd_time=forward_FA_time,
bwd_time=backward_FA_time if config.calculate_bwd_time else None,
)
@safe_backend("flex_attention", return_dict=True)
def run_single_experiment(
config: ExperimentConfig,
dynamic=False,
) -> dict[str, ExperimentResults]:
device = torch.device("cuda")
batch_size, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim = config.shape
query, key, value = generate_inputs(
batch_size,
q_heads,
q_seq_len,
kv_heads,
kv_seq_len,
head_dim,
config.dtype,
device,
requires_grad=config.calculate_bwd_time,
nested_tensors=config.attn_type == "document_mask",
)
score_mod = generate_score_mod(config.attn_type, config.shape)
block_mask, mask_kwargs = generate_block_mask(config.attn_type, config.shape)
kernel_options = get_kernel_options(config.attn_type, config.shape)
if config.max_autotune:
compiled_sdpa = torch.compile(
flex_attention, dynamic=dynamic, mode="max-autotune-no-cudagraphs"
)
else:
compiled_sdpa = torch.compile(flex_attention, dynamic=dynamic)
out_compile = compiled_sdpa(
query=query,
key=key,
value=value,
score_mod=score_mod,
block_mask=block_mask,
enable_gqa=True,
kernel_options=kernel_options,
)
forward_compiled_time = benchmark_torch_function_in_microseconds(
compiled_sdpa,
query,
key,
value,
score_mod=score_mod,
block_mask=block_mask,
enable_gqa=True,
kernel_options=kernel_options,
)
results = {}
for backend in config.backends:
if backend in ["fav3", "fakv"]:
results[backend] = run_single_backend_FA(
config,
query,
key,
value,
out_compile,
score_mod,
block_mask,
mask_kwargs,
backend,
)
else: # sdpa (also supports fav2)
results[backend] = run_single_backend_sdpa(
config,
query,
key,
value,
out_compile,
score_mod,
block_mask,
mask_kwargs,
backend,
)
if config.calculate_bwd_time:
d_out = torch.randn_like(out_compile)
backward_compile_time = benchmark_torch_function_in_microseconds(
out_compile.backward, d_out, retain_graph=True
)
sparsity = block_mask.sparsity() / 100.0 if block_mask is not None else 0.0
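    # For document masking, use a fixed ~50% sparsity estimate instead of the measured block-mask sparsity.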
sparsity = sparsity if config.attn_type != "document_mask" else 0.5
results["flex"] = ExperimentResults(
fwd_time=forward_compiled_time,
bwd_time=backward_compile_time if config.calculate_bwd_time else None,
sparsity=sparsity,
)
return results
def calculate_speedup(
results: ExperimentResults, baseline_results: ExperimentResults, type: str
) -> float:
if type == "fwd":
return baseline_results.fwd_time / results.fwd_time
elif type == "bwd":
assert results.bwd_time is not None
return baseline_results.bwd_time / results.bwd_time
else:
raise ValueError(f"Invalid type {type}")
def calculate_bandwidth(
config: ExperimentConfig, results: ExperimentResults, type: str
) -> float:
B, Hq, M, Hkv, N, D = config.shape
sparsity = results.sparsity if M == 1 else 0.0
if type == "fwd":
batch_size, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim = config.shape
query_size = (
batch_size
* q_heads
* q_seq_len
* head_dim
* torch.finfo(config.dtype).bits
/ 8
)
kv_size = (
batch_size
* kv_heads
* kv_seq_len
* head_dim
* torch.finfo(config.dtype).bits
/ 8
* 2
)
output_size = query_size
total_size = (
query_size + kv_size * (1 - sparsity) + output_size
) / 1e9 # In GB
time_in_seconds = results.fwd_time / 1e6
        return total_size / time_in_seconds / 1e3  # GB/s -> TB/s
else:
raise ValueError(f"Invalid type {type}")
def calculate_tflops(config: ExperimentConfig, results: ExperimentResults) -> float:
(B, Hq, M, Hkv, N, D) = config.shape
qk_flops = M * N * D * 2
softmax_flops = M * N * 2 # Not counting online softmax overhead
o_flops = M * D * N * 2
# Not counting split k overhead
sparsity = results.sparsity if results.sparsity is not None else 0.0
total_flops = B * Hq * (qk_flops + softmax_flops + o_flops) * (1 - sparsity)
    return total_flops / results.fwd_time / 1e6  # in TFLOPs/s
def get_average_speedups(results: list[Experiment], type: str, backend: str):
# Calculate speedups
speedups = [
calculate_speedup(r.results["flex"], r.results[backend], type) for r in results
]
# Find indices of max and min speedups
max_speedup_index = np.nanargmax(speedups)
min_speedup_index = np.nanargmin(speedups)
# Get the config dictionaries
max_config_dict = results[max_speedup_index].config.asdict()
min_config_dict = results[min_speedup_index].config.asdict()
# Create table data
table_data = [
{
"Type": "Average",
"Speedup": np.nanmean(speedups),
**dict.fromkeys(max_config_dict),
},
{"Type": "Max", "Speedup": speedups[max_speedup_index], **max_config_dict},
{"Type": "Min", "Speedup": speedups[min_speedup_index], **min_config_dict},
]
return table_data
def print_results(results: list[Experiment], save_path: Optional[str] = None):
table_data = defaultdict(list)
for experiment in results:
backends = experiment.config.backends + ["flex"]
for key, value in experiment.asdict().items():
if key in backends:
if value.fwd_time:
table_data[f"fwd_{key}"].append(float(value.fwd_time))
if value.bwd_time:
table_data[f"bwd_{key}"].append(float(value.bwd_time))
else:
table_data[key].append(value)
# Calculate speedups
for backend in results[0].config.backends:
fwd_speedups = [
calculate_speedup(r.results["flex"], r.results[backend], type="fwd")
for r in results
]
table_data[f"fwd_speedup_flex_over_{backend}"] = fwd_speedups
if results[0].config.calculate_bwd_time:
for backend in results[0].config.backends:
bwd_speedups = [
calculate_speedup(r.results["flex"], r.results[backend], type="bwd")
for r in results
]
table_data[f"bwd_speedup_flex_over_{backend}"] = bwd_speedups
# Calculate mem + computational throughput
if results[0].config.cal_bandwidth:
fwd_bandwidth = [
calculate_bandwidth(r.config, r.results["flex"], type="fwd")
for r in results
]
table_data["fwd_mem_bw (TB/s)"] = fwd_bandwidth
fwd_tflops = [calculate_tflops(r.config, r.results["flex"]) for r in results]
table_data["TFlops/s"] = fwd_tflops
print(tabulate(table_data, headers="keys", tablefmt="github", floatfmt=".3f"))
for backend in results[0].config.backends:
if np.isnan(table_data[f"fwd_speedup_flex_over_{backend}"]).all():
continue
print("\n")
print(f"FWD Speedup of Flex over {backend}".center(125, "="))
print("\n")
average_data = get_average_speedups(results, type="fwd", backend=backend)
print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))
if results[0].config.calculate_bwd_time:
print("\n")
print(f"BWD Speedup of Flex over {backend}".center(125, "="))
print("\n")
average_data = get_average_speedups(results, type="bwd", backend=backend)
print(
tabulate(
average_data, headers="keys", tablefmt="github", floatfmt=".3f"
)
)
if save_path is not None:
with open(save_path, "w", newline="") as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=table_data.keys())
writer.writeheader()
for i in range(len(next(iter(table_data.values())))):
row = {k: v[i] for k, v in table_data.items()}
writer.writerow(row)
print(f"\nResults saved to {save_path}")
# Generate score_mods and BlockMasks
softcap_value = 50
dropout_p = 0.0
def generate_score_mod(attn_type: str, shape: tuple[int]) -> Callable | None:
B, Hq, M, Hkv, N, D = shape
is_decoding = M == 1
from attn_gym.mods import generate_alibi_bias, generate_tanh_softcap
def relative_bias(score, b, h, m, n):
return score + (m - n)
def head_bias(score, b, h, m, n):
return score + 2 * h
function_dict = {
"noop": None,
"causal": None,
"rel": relative_bias,
"head_bias": head_bias,
"alibi": generate_alibi_bias(Hq),
"sliding_window": None,
"document_mask": None,
"prefix_lm": None,
"softcap": generate_tanh_softcap(softcap_value, approx=True),
}
score_mod = function_dict[attn_type]
is_decoding = M == 1
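    # Decoding (M == 1) assumes N // 2 tokens are already cached; shift the query index so the score mod sees absolute positions.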
if is_decoding and score_mod:
offset = torch.tensor(N // 2).to("cuda")
def score_mod_w_offset(score, b, h, m, n):
return score_mod(score, b, h, m + offset, n)
new_score_mod = score_mod_w_offset
else:
new_score_mod = score_mod
return new_score_mod
sliding_window_size = 512
prefix_length = 512
def generate_block_mask(attn_type: str, shape: tuple[int]):
B, Hq, M, Hkv, N, D = shape
is_decoding = M == 1
def causal(b, h, m, n):
return m >= n
def gen_offset(off):
def offset(b, h, m, n):
return m + off >= n
return offset
from attn_gym.masks import (
generate_doc_mask_mod,
generate_prefix_lm_mask,
generate_sliding_window,
)
from attn_gym.masks.document_mask import length_to_offsets
def generate_random_lengths(total_length, num_documents):
# Initialize all lengths to 1 to ensure each document has at least one token
lengths = [1] * num_documents
remaining_length = total_length - num_documents
# Randomly distribute the remaining length
for _ in range(remaining_length):
index = random.randint(0, num_documents - 1)
lengths[index] += 1
return lengths
mask_mod_kwargs = {}
assert attn_type != "document_mask" or not is_decoding
if attn_type == "document_mask":
random.seed(0)
lengths = generate_random_lengths(N * B, B)
mask_mod_kwargs = dict(offsets=length_to_offsets(lengths, "cuda"))
mask_mod_dict = {
"noop": None,
"causal": causal,
"rel": None,
"head_bias": None,
"alibi": causal,
"sliding_window": generate_sliding_window(sliding_window_size),
"document_mask": partial(generate_doc_mask_mod, mask_mod=causal),
"prefix_lm": generate_prefix_lm_mask(prefix_length),
"softcap": causal,
}
mask_mod = mask_mod_dict[attn_type]
if mask_mod_kwargs:
mask_mod = mask_mod(**mask_mod_kwargs)
if is_decoding and mask_mod:
cached_seq_len = torch.tensor(N // 2).to("cuda")
def decoding_w_cached_seq_len(b, h, m, n):
return mask_mod(b, h, m + cached_seq_len, n)
new_mask_mod = decoding_w_cached_seq_len
else:
new_mask_mod = mask_mod
mask_shape = (1, 1, M, N) if attn_type != "document_mask" else (1, 1, M * B, N * B)
compiled_block_mask = torch.compile(create_block_mask)
if new_mask_mod:
block_mask = compiled_block_mask(new_mask_mod, *mask_shape, "cuda")
else:
block_mask = compiled_block_mask(noop_mask, *mask_shape, "cuda")
return block_mask, mask_mod_kwargs
def get_kernel_options(attn_type: str, shape: tuple[int]):
B, Hq, M, Hkv, N, D = shape
is_decoding = M == 1
kernel_opt_training_dict = {
"noop": None,
"causal": None,
"rel": None,
"head_bias": None,
"alibi": None,
"sliding_window": None,
"document_mask": {
"BLOCK_N": 32,
"BLOCK_M": 128,
"fwd_num_warps": 8,
"fwd_num_stages": 4,
"BLOCK_M1": 64,
"BLOCK_N1": 64,
"BLOCK_M2": 64,
"BLOCK_N2": 64,
}
if torch.cuda.get_device_capability() >= (8, 0) and D <= 128
else None,
"prefix_lm": None,
"softcap": None,
}
    def get_default_split_k(B: int, H: int, Mk: int) -> int:
        """Heuristic for the number of splits, taken from xformers."""
        num_SM = torch.cuda.get_device_properties("cuda").multi_processor_count
        bh = max(B * H, 1)  # NOTE: Handle B*h=0 case
        split_k = num_SM // bh * 2  # Each SM should at least get one block.
        split_k = max(split_k, 1)
        return split_k
kernel_opt_decoding_dict = {
"noop": None,
"causal": {"SPLIT_KV": get_default_split_k(B, Hkv, N) * 2},
"rel": None,
"head_bias": None,
"alibi": {"SPLIT_KV": get_default_split_k(B, Hkv, N) * 2},
"sliding_window": None,
"document_mask": None,
"prefix_lm": None,
"softcap": {"SPLIT_KV": get_default_split_k(B, Hkv, N) * 2},
}
return (
kernel_opt_decoding_dict[attn_type]
if is_decoding
else kernel_opt_training_dict[attn_type]
)
# Setup Backend
def get_backend_context(backend: str):
"""
Returns a context manager for the specified backend.
Args:
backend (str): The name of the backend to use.
Valid options are 'math', 'efficient', 'cudnn', 'fav2', 'fav3', 'fakv', 'og-eager'.
Returns:
A context manager for the specified backend.
Raises:
ValueError: If an invalid backend is specified.
"""
backends = {
"fav2": sdpa_kernel(SDPBackend.FLASH_ATTENTION),
"cudnn": sdpa_kernel(SDPBackend.CUDNN_ATTENTION),
"math": sdpa_kernel(SDPBackend.MATH),
"efficient": sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION),
"fav3": nullcontext(),
"fakv": nullcontext(),
"og-eager": nullcontext(),
}
if backend not in backends:
raise ValueError(
f"Unknown backend: {backend}. Valid options are: {', '.join(backends.keys())}"
)
return backends[backend]
def generate_FA_callable(
attn_type: str, shape: tuple[int], dtype: torch.dtype, backend: str, **kwargs
) -> Callable | None:
if dtype not in [torch.float16, torch.bfloat16]:
return None
if backend == "fav3":
try:
from flash_attn.flash_attn_interface import (
flash_attn_func,
flash_attn_varlen_func,
)
except ImportError:
print(
"Flash attention 3 is not installed. Please install it to run fav3 backend. "
)
raise
else:
print("Unknown backend " + backend)
return None
B, Hq, M, Hkv, N, D = shape
FA_kwargs = {}
if attn_type == "alibi":
h = torch.arange(Hq, dtype=torch.float32, device="cuda")
alibi_slopes = torch.exp2(-((h + 1) * 8.0 / Hq))
FA_kwargs = dict(alibi_slopes=alibi_slopes)
elif attn_type == "document_mask":
FA_kwargs["cu_seqlens_q"] = kwargs["offsets"].to(torch.int32)
FA_kwargs["cu_seqlens_k"] = kwargs["offsets"].to(torch.int32)
def offsets_to_lengths(
offsets: torch.Tensor, device: Union[str, torch.device]
        ) -> torch.Tensor:
lengths = offsets[1:] - offsets[:-1]
return lengths
lengths = offsets_to_lengths(kwargs["offsets"], "cpu")
max_length = torch.max(lengths)
FA_kwargs["max_seqlen_q"] = max_length
FA_kwargs["max_seqlen_k"] = max_length
FA_dict = {
"noop": partial(flash_attn_func, causal=False),
"causal": partial(flash_attn_func, causal=True),
"rel": None,
"head_bias": None,
"alibi": partial(flash_attn_func, causal=True, **FA_kwargs),
"sliding_window": partial(
flash_attn_func, window_size=(sliding_window_size, 0), causal=True
),
"document_mask": partial(flash_attn_varlen_func, causal=True, **FA_kwargs),
"prefix_lm": None,
"softcap": partial(flash_attn_func, softcap=softcap_value, causal=True),
}
return FA_dict[attn_type]
def generate_FD_callable(
attn_type: str, shape: tuple[int], dtype: torch.dtype
) -> Callable | None:
if dtype not in [torch.float16, torch.bfloat16]:
return None
try:
from flash_attn import flash_attn_with_kvcache
except ImportError:
print(
"Flash attention 2 is not installed. Please install it to run fakv backend. "
)
raise
B, Hq, M, Hkv, N, D = shape
assert M == 1
def flash_attn_with_kvcache_renamed(q, k, v, **kwargs):
return flash_attn_with_kvcache(q, k_cache=k, v_cache=v, **kwargs)
FA_kwargs = {}
if attn_type == "alibi":
h = torch.arange(Hq, dtype=torch.float32, device="cuda")
alibi_slopes = torch.exp2(-((h + 1) * 8.0 / Hq))
FA_kwargs = dict(alibi_slopes=alibi_slopes)
FD_dict = {
"noop": partial(flash_attn_with_kvcache_renamed, causal=False),
"causal": partial(flash_attn_with_kvcache_renamed, cache_seqlens=N // 2),
"rel": None,
"head_bias": None,
"alibi": partial(
flash_attn_with_kvcache_renamed, cache_seqlens=N // 2, **FA_kwargs
),
"sliding_window": partial(
flash_attn_with_kvcache_renamed,
cache_seqlens=N // 2,
window_size=(sliding_window_size, 0),
),
"document_mask": None,
"prefix_lm": None,
"softcap": partial(flash_attn_with_kvcache_renamed, softcap=softcap_value),
}
return FD_dict[attn_type]
def generate_attn_mask_linear_score_mod(
shape: tuple[int], block_mask: BlockMask, score_mod: Callable, dtype: torch.dtype
):
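    """Materialize a linear score_mod as a dense additive bias, with -inf wherever block_mask disallows attention."""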
B, Hq, M, N = shape
if block_mask is None and score_mod is None:
return None
b = torch.arange(B, dtype=int, device="cuda")
h = torch.arange(Hq, dtype=int, device="cuda")
m = torch.arange(M, dtype=int, device="cuda")
n = torch.arange(N, dtype=int, device="cuda")
score = torch.zeros(B, Hq, M, N, dtype=dtype, device="cuda")
bias = score_mod(
score,
b[:, None, None, None],
h[None, :, None, None],
m[None, None, :, None],
n[None, None, None, :],
)
bool_mask = create_mask(block_mask.mask_mod, B, Hq, M, N, device="cuda")
attn_mask = bias.masked_fill(bool_mask.logical_not(), float("-inf"))
return attn_mask.to(dtype)
def generate_eager_sdpa(
attn_type: str,
shape: tuple[int],
dtype: torch.dtype,
block_mask: BlockMask,
score_mod: Callable | None = None,
**kwargs,
) -> Callable | None:
B, Hq, M, Hkv, N, D = shape
is_decoding = M == 1
if attn_type == "sliding_window" or attn_type == "prefix_lm":
attn_mask = create_mask(block_mask.mask_mod, 1, 1, M, N, device="cuda")
elif attn_type == "rel":
attn_mask = generate_attn_mask_linear_score_mod(
[1, 1, M, N], block_mask, score_mod, dtype
)
elif attn_type == "head_bias":
h = torch.arange(Hq, dtype=int, device="cuda")
attn_mask = (2 * h[None, :, None, None]).broadcast_to(1, Hq, M, N).to(dtype)
elif attn_type == "alibi":
attn_mask = generate_attn_mask_linear_score_mod(
[1, Hq, M, N], block_mask, score_mod, dtype
)
else:
attn_mask = None
sdpa_dict = {
"noop": partial(
F.scaled_dot_product_attention, is_causal=False, enable_gqa=(Hq != Hkv)
),
"causal": partial(
F.scaled_dot_product_attention, is_causal=True, enable_gqa=(Hq != Hkv)
),
"rel": partial(
F.scaled_dot_product_attention, is_causal=False, enable_gqa=(Hq != Hkv)
),
"head_bias": partial(
F.scaled_dot_product_attention, is_causal=False, enable_gqa=(Hq != Hkv)
),
"alibi": partial(
F.scaled_dot_product_attention, is_causal=False, enable_gqa=(Hq != Hkv)
),
"sliding_window": partial(
F.scaled_dot_product_attention, is_causal=False, enable_gqa=(Hq != Hkv)
),
"document_mask": partial(
F.scaled_dot_product_attention, is_causal=True, enable_gqa=(Hq != Hkv)
)
if Hq == Hkv
else None,
"prefix_lm": partial(
F.scaled_dot_product_attention, is_causal=False, enable_gqa=(Hq != Hkv)
),
"softcap": None,
}
if is_decoding and attn_type == "causal":
attn_mask = create_mask(block_mask.mask_mod, 1, 1, M, N, device="cuda")
sdpa_dict["causal"] = partial(
F.scaled_dot_product_attention, is_causal=False, enable_gqa=(Hq != Hkv)
)
return (
partial(sdpa_dict[attn_type], attn_mask=attn_mask)
if sdpa_dict[attn_type]
else None
)
def generate_experiment_configs(
calculate_bwd: bool,
dtype: torch.dtype,
batch_sizes: list[int],
num_heads: list[tuple[int, int]],
seq_lens: list[int],
head_dims: list[int],
score_mods_str: list[str],
decoding: bool,
kv_cache_size: list[int],
cal_bandwidth: bool,
backends: list[str],
max_autotune: bool,
) -> list[ExperimentConfig]:
assert not (calculate_bwd and decoding), "Decoding does not support backward"
if decoding:
q_kv_seq_lens = [(1, i) for i in seq_lens] # only testing query length == 1
else:
q_kv_seq_lens = [(i, i) for i in seq_lens] # only testing q_len == kv_len
dtypes = [dtype]
all_configs = []
for (
bsz,
(q_heads, kv_heads),
(q_seq_len, kv_seq_len),
head_dim,
attn_type,
dtype,
) in itertools.product(
kv_cache_size if kv_cache_size else batch_sizes,
num_heads,
q_kv_seq_lens,
head_dims,
score_mods_str,
dtypes,
):
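        # When --kv-size is given, derive the batch size from the KV-cache budget:
        # kv_size is in MiB and each sample stores two tensors (K and V) of
        # kv_heads * kv_seq_len * head_dim elements.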
if kv_cache_size:
head_size_bytes = torch.finfo(dtype).bits / 8 * head_dim
bsz = int(
(bsz * 1024 * 1024) // (kv_heads * kv_seq_len * head_size_bytes * 2)
)
if bsz <= 0:
continue
assert q_heads % kv_heads == 0
all_configs.append(
ExperimentConfig(
shape=(bsz, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim),
attn_type=attn_type,
dtype=dtype,
calculate_bwd_time=calculate_bwd,
cal_bandwidth=cal_bandwidth,
backends=backends,
max_autotune=max_autotune,
)
)
return all_configs
def _output_json_for_dashboard(
experiments,
output_file,
benchmark_name="PyTorch operator microbenchmark",
):
"""
Write the result into JSON format for PyTorch OSS dashboard.
The JSON format is defined at
https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
Args:
experiments: List of experiment results
output_file: Path to output JSON file
benchmark_name: Name of the benchmark
"""
if not experiments:
return
import math
import platform
from dataclasses import asdict, dataclass
from typing import Any, Optional
# Prepare headers and records for JSON output
records = []
for experiment in experiments:
config = experiment.config
results_dict = (
experiment.results
) # This is a dict: backend -> ExperimentResults
# Process each backend result
for backend, results in results_dict.items():
# Skip backends that were not run (NaN results)
if math.isnan(results.fwd_time):
continue
# Extract data from experiment
test_name = f"{backend}_{config.attn_type}_"
input_config = f"shape: {config.shape}, dtype: {config.dtype}"
# Determine mode based on backward pass
mode = "training" if config.calculate_bwd_time else "inference"
# Extract dtype
dtype = (
str(config.dtype).split(".")[1]
if "." in str(config.dtype)
else str(config.dtype)
)
# Determine device
device = "cuda"
# Get device architecture
device_arch = (
torch.cuda.get_device_name(0)
if device == "cuda"
else platform.processor()
if device == "cpu"
else "unknown"
)
# Create dataclasses for JSON structure
@dataclass
class BenchmarkInfo:
name: str
mode: Optional[str]
dtype: str
extra_info: dict[str, Any]
@dataclass
class ModelInfo:
name: str
type: str
origins: list[str]
extra_info: dict[str, Any]
@dataclass
class MetricInfo:
name: str
unit: str
benchmark_values: list[float]
target_value: Optional[float]
@dataclass
class BenchmarkRecord:
benchmark: BenchmarkInfo
model: ModelInfo
metric: MetricInfo
# Benchmark extra info
benchmark_extra_info = {
"input_config": input_config,
"device": device,
"arch": device_arch,
"operator_name": backend,
"attn_type": config.attn_type,
"shape": str(config.shape),
"max_autotune": config.max_autotune,
}
# Add record for forward latency
record_fwd_latency = BenchmarkRecord(
benchmark=BenchmarkInfo(
name=benchmark_name,
mode=mode,
dtype=dtype,
extra_info=benchmark_extra_info,
),
model=ModelInfo(
name=test_name + str(config.shape),
type="attention-benchmark",
origins=["pytorch"],
extra_info={
"operator_name": backend,
"attn_type": config.attn_type,
},
),
metric=MetricInfo(
name="forward latency",
unit="us",
benchmark_values=[results.fwd_time],
target_value=None,
),
)
records.append(asdict(record_fwd_latency))
# Add record for forward memory bandwidth (if available)
if config.cal_bandwidth:
record_fwd_bandwidth = BenchmarkRecord(
benchmark=BenchmarkInfo(
name=benchmark_name,
mode=mode,
dtype=dtype,
extra_info=benchmark_extra_info,
),
model=ModelInfo(
name=test_name + str(config.shape),
type="attention-benchmark",
origins=["pytorch"],
extra_info={
"operator_name": backend,
},
),
metric=MetricInfo(
name="memory bandwidth",
unit="TB/s",
benchmark_values=[calculate_bandwidth(config, results, "fwd")],
target_value=None,
),
)
records.append(asdict(record_fwd_bandwidth))
# Add record for forward TFLOPS (if available)
if config.cal_bandwidth:
record_fwd_tflops = BenchmarkRecord(
benchmark=BenchmarkInfo(
name=benchmark_name,
mode=mode,
dtype=dtype,
extra_info=benchmark_extra_info,
),
model=ModelInfo(
name=test_name + str(config.shape),
type="attention-benchmark",
origins=["pytorch"],
extra_info={
"operator_name": backend,
},
),
metric=MetricInfo(
name="tflops",
unit="TFLOPS/s",
benchmark_values=[calculate_tflops(config, results)],
target_value=None,
),
)
records.append(asdict(record_fwd_tflops))
# Add record for backward latency (if available and not NaN)
if (
config.calculate_bwd_time
and results.bwd_time is not None
and not math.isnan(results.bwd_time)
):
record_bwd_latency = BenchmarkRecord(
benchmark=BenchmarkInfo(
name=benchmark_name,
mode=mode,
dtype=dtype,
extra_info=benchmark_extra_info,
),
model=ModelInfo(
name=test_name + str(config.shape),
type="attention-benchmark",
origins=["pytorch"],
extra_info={
"operator_name": backend,
},
),
metric=MetricInfo(
name="backward latency",
unit="us",
benchmark_values=[results.bwd_time],
target_value=None,
),
)
records.append(asdict(record_bwd_latency))
# Write all records to the output file
with open(output_file, "w", encoding="utf-8") as f:
json.dump(records, f, indent=2)
def main(
dynamic: bool = False,
calculate_bwd: bool = False,
dtype: DtypeString = "bfloat16",
b: list[int] | None = None,
nh: list[str] | None = None,
s: list[int] | None = None,
d: list[int] | None = None,
mods: list[AttentionType] | None = None,
backend: list[Backend] | None = None,
max_autotune: bool = False,
decoding: bool = False,
kv_size: Optional[list[int]] = None,
throughput: bool = True,
save_path: Optional[str] = None,
output_json_for_dashboard: Optional[str] = None,
benchmark_name: str = "PyTorch operator microbenchmark",
) -> None:
"""Run sweep over sizes and score mods for flex attention.
Usage Examples:
# Use a yml config file
python score_mod.py --config basic_config.yaml
# Use a json config file
python score_mod.py --config my_config.json
# Generate a config template
python score_mod.py --print-config json > my_config.json # For a json config
python score_mod.py --print-config yaml > my_config.yaml # For a yaml config
# Override config with CLI args
python score_mod.py --config my_config.json -dtype float16 --max-autotune
# Pure CLI usage
python score_mod.py -b 4 8 -s 1024 2048 -mods causal alibi --backend efficient
Args:
dynamic: Runs a dynamic shapes version of compiled flex attention
calculate_bwd: Calculate backward pass times
dtype: Data type for tensors (bfloat16, float16, float32)
b: Batch sizes to benchmark
nh: Number of query and key/value heads in format "Hq,Hkv"
s: Sequence lengths to benchmark
d: Head dimensions to benchmark
mods: Score modifications: noop, causal, rel, head_bias, alibi, sliding_window, document_mask, prefix_lm, softcap
backend: Backends for attention computation: math, efficient, cudnn, fav2, fav3, fakv, og-eager
max_autotune: Turn on max-autotune optimization
decoding: Benchmark decoding mode (query sequence length = 1)
kv_size: Key/value cache size in MiB (ignores batch size if specified)
throughput: Calculate kernel memory bandwidth & computational throughput (always True)
save_path: Path to save the results CSV file
output_json_for_dashboard: Path to save results in JSON format for PyTorch OSS dashboard
benchmark_name: Name of the benchmark for dashboard output
"""
# Convert dtype string to torch dtype (if not already converted)
import torch
if isinstance(dtype, str):
dtype = getattr(torch, dtype)
# Always calculate throughput
throughput = True
print("Backend: ", backend)
seed = 123
np.random.seed(seed)
torch.manual_seed(seed)
results = []
for experiment_count, config in enumerate(
tqdm(
generate_experiment_configs(
calculate_bwd,
dtype,
b,
nh,
s,
d,
mods,
decoding,
kv_size,
throughput,
backend,
max_autotune,
)
),
start=1,
):
results.append(
Experiment(
config,
run_single_experiment(
config,
dynamic=dynamic,
),
)
)
# Periodic memory cleanup every 50 experiments
if experiment_count % 50 == 0:
cleanup_memory()
print_results(results, save_path)
# Output JSON for dashboard if requested
if output_json_for_dashboard:
_output_json_for_dashboard(results, output_json_for_dashboard, benchmark_name)
if __name__ == "__main__":
# Set up the argument parser
parser = argparse.ArgumentParser(
description="Run sweep over sizes and score mods for flex attention"
)
parser.add_argument(
"--config",
type=str,
help="Path to JSON config file. CLI args override config file values.",
default=None,
)
parser.add_argument(
"--dynamic",
action="store_true",
help="Runs a dynamic shapes version of compiled flex attention.",
)
parser.add_argument(
"--calculate-bwd", action="store_true", help="Calculate backward pass times"
)
parser.add_argument("-dtype", type=str, help="dtype", default="bfloat16")
parser.add_argument(
"-b", type=int, nargs="+", help="batch sizes", default=[2, 8, 16]
)
parser.add_argument(
"-nh",
type=heads_input_type,
nargs="+",
help="# of q-heads,kv-heads",
default=[(16, 16), (16, 2)],
)
parser.add_argument(
"-s", type=int, nargs="+", help="sequence lengths", default=[512, 1024, 4096]
)
parser.add_argument("-d", type=int, nargs="+", help="head dims", default=[64, 128])
parser.add_argument(
"-mods",
type=str,
nargs="+",
help="score mods: noop, causal, rel, head_bias, alibi, sliding_window, document_mask, prefix_lm, softcap",
default=["noop", "causal", "alibi", "sliding_window"],
)
parser.add_argument(
"--max-autotune", action="store_true", help="Turn on max-autotune"
)
parser.add_argument(
"--decoding",
action="store_true",
help="Benchmark Decoding (query sequence length = 1)",
)
parser.add_argument(
"--kv-size",
type=int,
nargs="+",
required=False,
help="""
key/value size in MiB.
        Ignores -b batch size and calculates the batch size from the KV size instead when specified.
""",
)
parser.add_argument(
"--throughput",
action="store_true",
help="Calculate kernel memory bandwidth & computational throughput. ",
)
parser.add_argument(
"--save-path",
type=str,
        help="Path to save the results CSV file (optional)",
default=None,
)
parser.add_argument(
"--backend",
type=str,
nargs="+",
choices=["math", "efficient", "cudnn", "fav2", "fav3", "fakv"],
default=["efficient"],
help="Backend to use for attention computation",
)
parser.add_argument(
"--output-json-for-dashboard",
type=str,
help="Path to save results in JSON format for PyTorch OSS dashboard",
default=None,
)
parser.add_argument(
"--benchmark-name",
type=str,
help="Name of the benchmark for dashboard output",
default="PyTorch operator microbenchmark",
)
parser.add_argument(
"--print-config",
type=str,
choices=["json", "yaml"],
help="Print a default config template in JSON or YAML format and exit",
default=None,
)
# Parse arguments
args = parser.parse_args()
# Handle --print-config
if args.print_config:
print_default_config(args.print_config)
sys.exit(0)
# Load and merge config if provided
if args.config:
config = load_config_file(args.config)
# Merge config with CLI args (CLI args take precedence)
json_args = argparse.Namespace()
json_args.__dict__ = config
args = parser.parse_args(namespace=json_args)
# Convert dtype string to torch dtype (only if it's still a string)
if isinstance(args.dtype, str):
args.dtype = getattr(torch, args.dtype)
# Remove config and print_config from args before passing to main
args_dict = vars(args)
args_dict.pop("config", None)
args_dict.pop("print_config", None)
main(**args_dict)