Add flex decoding benchmark (#130850)

ghstack-source-id: b4f26fb66ed47907b11580c8c853737959c58811
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130788

Add benchmark for flex decoding.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130850
Approved by: https://github.com/Chillee, https://github.com/drisspg
Author: joydddd
Date: 2024-07-18 18:09:25 +00:00
Committed by: PyTorch MergeBot
Parent: fff92d4f18
Commit: 6d9f74f0af


@@ -11,7 +11,7 @@ from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
torch._dynamo.config.automatic_dynamic_shapes = False
@@ -35,15 +35,20 @@ class ExperimentConfig:
score_mod: Callable
dtype: torch.dtype
calculate_bwd_time: bool
cal_bandwidth: bool
def __post_init__(self):
assert len(self.shape) == 4, "Shape must be of length 4"
assert (
len(self.shape) == 6
), "Shape must be of length 6" # [B, Hq, M, Hkv, N, D]
def asdict(self):
# Convert the dataclass instance to a dictionary
d = asdict(self)
# Remove the 'calculate_bwd_time' key
# Remove the 'calculate_bwd_time' and 'cal_bandwidth' keys
d.pop("calculate_bwd_time", None)
d.pop("cal_bandwidth", None)
d["shape(B,Hq,M,Hkv,N,D)"] = d.pop("shape")
return d
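# Illustrative sketch (not part of the diff) of what asdict now returns for
# the new 6-tuple shape, with made-up values:
#   ExperimentConfig(shape=(2, 16, 1, 2, 4096, 64), score_mod=mod,
#                    dtype=torch.bfloat16, calculate_bwd_time=False,
#                    cal_bandwidth=True).asdict()
#   -> {"score_mod": mod, "dtype": torch.bfloat16,
#       "shape(B,Hq,M,Hkv,N,D)": (2, 16, 1, 2, 4096, 64)}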
@@ -72,16 +77,21 @@ class Experiment:
def generate_inputs(
batch_size: int,
num_heads: int,
q_heads: int,
q_sequence_length: int,
kv_heads: int,
kv_sequence_length: int,
head_dim: int,
dtype: torch.dtype,
device: torch.device,
requires_grad: bool,
):
q_shape = (batch_size, q_sequence_length, num_heads * head_dim)
kv_shape = (batch_size, kv_sequence_length, num_heads * head_dim)
q_shape = (batch_size, q_sequence_length, q_heads * head_dim)
kv_shape = (batch_size, kv_sequence_length, kv_heads * head_dim)
assert q_heads % kv_heads == 0
num_h_groups = q_heads // kv_heads
make_q = partial(
torch.rand, q_shape, device=device, dtype=dtype, requires_grad=requires_grad
@@ -91,32 +101,33 @@ def generate_inputs(
)
query = (
make_q()
.view(batch_size, q_sequence_length, num_heads, head_dim)
.view(batch_size, num_h_groups * q_sequence_length, kv_heads, head_dim)
.transpose(1, 2)
)
key = (
make_kv()
.view(batch_size, kv_sequence_length, num_heads, head_dim)
.view(batch_size, kv_sequence_length, kv_heads, head_dim)
.transpose(1, 2)
)
value = (
make_kv()
.view(batch_size, kv_sequence_length, num_heads, head_dim)
.view(batch_size, kv_sequence_length, kv_heads, head_dim)
.transpose(1, 2)
)
return query, key, value
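# Shape sketch (illustrative, not part of the diff): the view above folds the
# G = q_heads // kv_heads query heads into the sequence dimension, e.g.
#   B, Hq, M, Hkv, D = 2, 16, 1, 2, 64
#   G = 16 // 2                                    # 8 query heads per kv head
#   q = torch.rand(B, M, Hq * D)                   # (2, 1, 1024)
#   q = q.view(B, G * M, Hkv, D).transpose(1, 2)   # (2, 2, 8, 64)
# so the kernel sees Hkv heads, each with an effective query length of G * M.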
def run_single_experiment(
config: ExperimentConfig, dynamic=False, max_autotune=False
config: ExperimentConfig, dynamic=False, max_autotune=False, enable_mask=False
) -> ExperimentResults:
device = torch.device("cuda")
batch_size, num_heads, q_seq_len, head_dim = config.shape
batch_size, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim = config.shape
query, key, value = generate_inputs(
batch_size,
num_heads,
q_seq_len,
q_heads,
q_seq_len,
kv_heads,
kv_seq_len,
head_dim,
config.dtype,
device,
@@ -135,11 +146,18 @@ def run_single_experiment(
score_mod = config.score_mod
if enable_mask:
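# Q_LEN is q_seq_len * (q_heads // kv_heads) because generate_inputs folds
# the query heads into the sequence dimension, so the mask covers G*M rows.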
block_mask = create_block_mask(
score_mod, 1, 1, q_seq_len * (q_heads // kv_heads), kv_seq_len, query.device
)
else:
block_mask = None
forward_eager_time = benchmark_torch_function_in_microseconds(
eager_sdpa, query, key, value, score_mod
)
forward_compiled_time = benchmark_torch_function_in_microseconds(
compiled_sdpa, query, key, value, score_mod
compiled_sdpa, query, key, value, score_mod, block_mask
)
if config.calculate_bwd_time:
@@ -176,10 +194,54 @@ def calculate_speedup(results: ExperimentResults, type: str) -> float:
raise ValueError(f"Invalid type {type}")
def calculate_bandwidth(
config: ExperimentConfig, results: ExperimentResults, type: str
) -> float:
if type == "fwd":
batch_size, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim = config.shape
query_size = (
batch_size
* q_heads
* q_seq_len
* head_dim
* torch.finfo(config.dtype).bits
/ 8
)
kv_size = (
batch_size
* kv_heads
* kv_seq_len
* head_dim
* torch.finfo(config.dtype).bits
/ 8
* 2
)
output_size = query_size
total_size = (query_size + kv_size + output_size) / 1e9 # In GB
time_in_seconds = results.fwd_times.compiled_time / 1e6
return total_size / time_in_seconds / 1e3
else:
raise ValueError(f"Invalid type {type}")
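# Worked example of the byte counting above (illustrative numbers, not
# benchmark output): B=32, Hq=16, M=1, Hkv=2, N=4096, D=64, bfloat16.
#   query_size = 32 * 16 * 1 * 64 * 2        # 65,536 B; output is the same size
#   kv_size    = 32 * 2 * 4096 * 64 * 2 * 2  # 67,108,864 B for K and V together
#   total_size = (2 * 65_536 + 67_108_864) / 1e9   # ~0.0672 GB per forward
# at a compiled time of 50 us this gives 0.0672 / 50e-6 / 1e3 ~= 1.34 TB/s.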
def calculate_tflops(config: ExperimentConfig, results: ExperimentResults) -> float:
(B, Hq, M, Hkv, N, D) = config.shape
qk_flops = M * N * D * 2
softmax_flops = M * N * 2 # Not counting online softmax overhead
o_flops = M * D * N * 2
# Not counting split k overhead
total_flops = B * Hq * (qk_flops + softmax_flops + o_flops)
return total_flops / results.fwd_times.compiled_time / 1e6  # in TFLOPs/s
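# Matching FLOP arithmetic for the same illustrative shape
# (B=32, Hq=16, M=1, N=4096, D=64):
#   qk_flops      = 1 * 4096 * 64 * 2   # 524,288
#   softmax_flops = 1 * 4096 * 2        # 8,192
#   o_flops       = 1 * 64 * 4096 * 2   # 524,288
#   total_flops   = 32 * 16 * 1_056_768 # ~5.41e8
# at 50 us this is 5.41e8 / 50 / 1e6 ~= 10.8 TFLOPs/s.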
def get_func_name(func):
return func.__name__.split("<locals>.")[-1].split(" at ")[0]
def set_func_name(func, name):
func.__name__ = name
def get_average_speedups(results: List[Experiment], type: str):
# Calculate speedups
speedups = [calculate_speedup(r.results, type) for r in results]
@@ -231,6 +293,16 @@ def print_results(results: List[Experiment]):
# Calculate speedups
fwd_speedups = [calculate_speedup(r.results, type="fwd") for r in results]
table_data["fwd_speedup"] = fwd_speedups
# Calculate mem + computational throughput
if results[0].config.cal_bandwidth:
fwd_bandwidth = [
calculate_bandwidth(r.config, r.results, type="fwd") for r in results
]
table_data["fwd_mem_bw (TB/s)"] = fwd_bandwidth
fwd_tflops = [calculate_tflops(r.config, r.results) for r in results]
table_data["TFlops/s"] = fwd_tflops
if results[0].config.calculate_bwd_time:
bwd_speedups = [calculate_speedup(r.results, type="bwd") for r in results]
table_data["bwd_speedup"] = bwd_speedups
@@ -274,36 +346,74 @@ def generate_score_mods(score_mods: List[str]) -> List[Callable]:
return [function_dict[name] for name in score_mods]
def get_gqa_score_mod(score_mod, G, q_seq_len):
def score_mod_gqa(score, b, hkv, m, n):
g = m // q_seq_len
new_m = m % q_seq_len
hq = hkv * G + g
return score_mod(score, b, hq, new_m, n)
score_mod_name = get_func_name(score_mod)
set_func_name(score_mod_gqa, score_mod_name + "_gqa")
return score_mod_gqa
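# Index remapping example (illustrative): with G = 8 and q_seq_len = 1,
# a call score_mod_gqa(score, b, hkv=1, m=5, n) computes
#   g = 5 // 1 = 5, new_m = 5 % 1 = 0, hq = 1 * 8 + 5 = 13
# i.e. flattened row 5 under kv head 1 is really row 0 of query head 13.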
def generate_experiment_configs(
calculate_bwd: bool,
dtype: torch.dtype,
batch_sizes: List[int],
num_heads: List[int],
num_heads: List[Tuple[int, int]],
seq_lens: List[int],
head_dims: List[int],
score_mods: List[str],
decoding: bool,
kv_cache_size: List[int],
cal_bandwidth: bool,
) -> List[ExperimentConfig]:
q_kv_seq_lens = [(i, i) for i in seq_lens] # only testing q_len == kv_len
assert not (calculate_bwd and decoding), "Decoding does not support backward"
if decoding:
q_kv_seq_lens = [(1, i) for i in seq_lens] # only testing query length == 1
else:
q_kv_seq_lens = [(i, i) for i in seq_lens] # only testing q_len == kv_len
dtypes = [dtype]
score_mods = generate_score_mods(score_mods)
all_configs = []
for (
bsz,
n_heads,
(q_heads, kv_heads),
(q_seq_len, kv_seq_len),
head_dim,
score_mod,
dtype,
) in itertools.product(
batch_sizes, num_heads, q_kv_seq_lens, head_dims, score_mods, dtypes
kv_cache_size if kv_cache_size else batch_sizes,
num_heads,
q_kv_seq_lens,
head_dims,
score_mods,
dtypes,
):
assert q_seq_len == kv_seq_len, "Only equal length inputs supported for now."
if kv_cache_size:
head_size_bytes = torch.finfo(dtype).bits / 8 * head_dim
bsz = int(
(bsz * 1024 * 1024) // (kv_heads * kv_seq_len * head_size_bytes * 2)
)
if bsz <= 0:
continue
if q_heads != kv_heads: # GQA work around before it's explicitly supported
assert q_heads % kv_heads == 0
G = q_heads // kv_heads
score_mod = get_gqa_score_mod(score_mod, G, q_seq_len)
all_configs.append(
ExperimentConfig(
shape=(bsz, n_heads, q_seq_len, head_dim),
shape=(bsz, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim),
score_mod=score_mod,
dtype=dtype,
calculate_bwd_time=calculate_bwd,
cal_bandwidth=cal_bandwidth,
)
)
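# Worked example of the kv_cache_size derivation above (assumed sizes,
# bfloat16): a 512 MiB cache with kv_heads=2, kv_seq_len=4096, head_dim=64
# gives head_size_bytes = 16 / 8 * 64 = 128, so
#   bsz = 512 * 1024 * 1024 // (2 * 4096 * 128 * 2) = 256
# K and V each hold bsz * kv_heads * kv_seq_len * head_dim elements, hence
# the trailing factor of 2.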
@@ -317,14 +427,26 @@ def main(args):
results = []
for config in tqdm(
generate_experiment_configs(
args.calculate_bwd, args.dtype, args.b, args.nh, args.s, args.d, args.mods
args.calculate_bwd,
args.dtype,
args.b,
args.nh,
args.s,
args.d,
args.mods,
args.decoding,
args.kv_cache_size,
args.cal_bandwidth,
)
):
results.append(
Experiment(
config,
run_single_experiment(
config, dynamic=args.dynamic, max_autotune=args.max_autotune
config,
dynamic=args.dynamic,
max_autotune=args.max_autotune,
enable_mask=args.mask,
),
)
)
@@ -332,6 +454,14 @@ def main(args):
print_results(results)
def heads_input_type(s):
try:
hq, hkv = map(int, s.split(","))
return hq, hkv
except Exception as e:
raise argparse.ArgumentTypeError("Heads must be Hq,Hkv") from e
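# Example (illustrative): heads_input_type("16,2") -> (16, 2), so -nh accepts
# pairs such as -nh 16,2 64,8; anything that does not parse as two
# comma-separated ints raises the ArgumentTypeError above.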
if __name__ == "__main__":
# Set up the argument parser
parser = argparse.ArgumentParser(
@@ -351,7 +481,13 @@ if __name__ == "__main__":
parser.add_argument(
"-b", type=int, nargs="+", help="batch sizes", default=[2, 8, 16]
)
parser.add_argument("-nh", type=int, nargs="+", help="# of heads", default=[16])
parser.add_argument(
"-nh",
type=heads_input_type,
nargs="+",
help="# of q-heads,kv-heads",
default=[(16, 16), (16, 2)],
)
parser.add_argument(
"-s", type=int, nargs="+", help="sequence lengths", default=[512, 1024, 4096]
)
@@ -366,6 +502,29 @@ if __name__ == "__main__":
parser.add_argument(
"--max-autotune", action="store_true", help="Turn on max-autotune"
)
parser.add_argument(
"--decoding",
action="store_true",
help="Benchmark Decoding (query sequence length = 1)",
)
parser.add_argument(
"--kv-cache-size",
type=int,
nargs="+",
required=False,
help="""
key/value cache size in MiB.
Ignores -b batch size and calculates batch size from the kv cache size instead when specified.
""",
)
parser.add_argument(
"--cal-bandwidth",
action="store_true",
help="Calculate kernel memory bandwidth & computational throughput. ",
)
parser.add_argument(
"--mask", action="store_true", help="Enables block sparsity mask. "
)
# Parse arguments
args = parser.parse_args()
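For reference, a decoding benchmark run with the new flags might look like this (script path assumed, not shown in the diff; only flags visible above are used):

    python benchmarks/transformer/score_mod.py --decoding --cal-bandwidth --mask -nh 16,2 -s 4096 --kv-cache-size 512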