Update the sdpa benchmark to measure forward and backward time in isolation (#115986)

# Summary

The benchmarks were getting a little stale, and I think it makes more sense to measure the SDPA forward and backward passes in isolation now rather than end-to-end inside an MHA component.

This is a prerequisite for gathering the data for https://github.com/pytorch/pytorch/pull/115357.
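
In short, the new script times the `scaled_dot_product_attention` forward and backward passes separately with `torch.utils.benchmark`. Below is a condensed sketch of the measurement pattern, mirroring the new file in the diff further down; the shapes, values, and helper names here are illustrative only:

```python
# Condensed sketch of the isolated forward/backward measurement (illustrative shapes).
import torch
import torch.utils.benchmark as benchmark
from torch.nn.functional import scaled_dot_product_attention


def time_us(func, *args, **kwargs) -> float:
    # A few warmup iterations, then report the median runtime in microseconds.
    for _ in range(5):
        func(*args, **kwargs)
    timer = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals={"func": func, "args": args, "kwargs": kwargs},
    )
    return timer.adaptive_autorange(min_run_time=0.1).median * 1e6


def make_qkv():
    # batch=8, heads=16, seq_len=512, head_dim=128 (embed_dim 2048 / 16 heads)
    return torch.randn(
        8, 16, 512, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True
    )


q, k, v = make_qkv(), make_qkv(), make_qkv()

# Forward pass timed on its own.
forward_us = time_us(scaled_dot_product_attention, q, k, v, is_causal=True)

# Backward pass timed on its own, replaying the same graph each iteration.
out = scaled_dot_product_attention(q, k, v, is_causal=True)
grad_out = torch.randn_like(out)
backward_us = time_us(out.backward, grad_out, retain_graph=True)

print(f"forward: {forward_us:.1f} us, backward: {backward_us:.1f} us")
```

Timing the backward call with `retain_graph=True` lets the same autograd graph be replayed on every measured iteration.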

Output from a run (times are in microseconds):
``` Shell
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
| batch_size | num_heads | q_seq_len | kv_seq_len | embed_dim | is_causal |     dtype      |    forward_time    |   backward_time    |
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
|     1      |    16     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 23.86634959839284  | 66.21150835417211  |
|     1      |    16     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 23.452017060481012 | 66.90612225793302  |
|     1      |    16     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 24.478124547749758 |  76.4232068322599  |
|     1      |    16     |    256    |    256     |   2048    |   False   | torch.bfloat16 |  24.6928428998217  | 75.76151192188263  |
|     1      |    16     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 28.69622849393636  | 114.73898496478796 |
|     1      |    16     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 34.399422979913645 | 112.96746158041059 |
|     1      |    16     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 |  65.4690912924707  | 216.26344555988908 |
|     1      |    16     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 88.57532404363155  | 212.07790216431025 |
|     8      |    16     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 11.582905380055308 | 70.09557797573505  |
|     8      |    16     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 12.068384909071026 | 70.01491216942668  |
|     8      |    16     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 31.671419646590945 | 203.54910241439939 |
|     8      |    16     |    256    |    256     |   2048    |   False   | torch.bfloat16 |  33.0585768679157  | 209.45609430782497 |
|     8      |    16     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 87.43969700299202  | 469.8729298543185  |
|     8      |    16     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 123.9265550393611  | 580.1084265112877  |
|     8      |    16     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 561.1918237991632  | 1181.655174586922  |
|     8      |    16     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 884.2707145959139  | 1662.4679416418073 |
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115986
Approved by: https://github.com/mikaylagawarecki
Author: drisspg
Date: 2023-12-18 22:40:43 +00:00
Committed by: PyTorch MergeBot
Parent: bf62511e07
Commit: 6b120c6cf9
2 changed files with 173 additions and 189 deletions

Deleted file (the old end-to-end MHA benchmark):

@@ -1,189 +0,0 @@
```python
import random

import numpy as np
import torch
import torch.utils.benchmark as benchmark
from torch.profiler import profile, ProfilerActivity, record_function


class CompositeMHA(torch.nn.Module):
    def __init__(self, num_heads, in_proj_weight, in_proj_bias, out_proj):
        super().__init__()
        self.in_proj_weight = in_proj_weight
        self.in_proj_bias = in_proj_bias
        self.out_proj = out_proj
        self.num_heads = num_heads

    def forward(self, query, key, value, mask):
        if not (query is key and key is value):
            raise NotImplementedError(
                "query, key and value must be the same Tensor for now."
            )
        if mask is not None:
            raise NotImplementedError("mask is currently not supported.")

        query_projected = torch.nn.functional.linear(
            query, self.in_proj_weight, self.in_proj_bias
        )

        batch_size = query_projected.size(0)
        embed_dim = query_projected.size(2)
        head_dim = embed_dim // (self.num_heads * 3)

        query, key, value = query_projected.chunk(3, -1)

        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        attn = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=False,
        )

        attn = attn.transpose(1, 2).reshape(batch_size, -1, self.num_heads * head_dim)
        # Match return signature of nn.MHA
        return self.out_proj(attn)


def build_composite_mha_from_nn_mha(pt):
    assert pt._qkv_same_embed_dim
    in_proj_weight = pt.in_proj_weight
    assert in_proj_weight is not None
    assert pt.batch_first
    return CompositeMHA(pt.num_heads, pt.in_proj_weight, pt.in_proj_bias, pt.out_proj)


def forw_back(model, input, upward):
    output = model(*input)
    output.backward(upward)


# Context manager not working in timer
def forw_back_fused(model, input, upward):
    with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True):
        output = model(*input)
        output.backward(upward)


def forw_back_eager(model, input, upward):
    with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False):
        output = model(*input)
        output.backward(upward)


def run_timing(
    min_run_time, batch_size, embed_dimension, num_heads, max_sequence_len, dtype
):
    dropout_p = 0.0
    mask = None

    pt = torch.nn.MultiheadAttention(
        embed_dim=embed_dimension,
        num_heads=num_heads,
        batch_first=True,
        dropout=dropout_p,
    )
    npt = pt.cuda().to(dtype)
    cpt = build_composite_mha_from_nn_mha(npt)

    x = torch.randn(
        batch_size,
        max_sequence_len,
        embed_dimension,
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True):
        rand_fused_upward = cpt(x, x, x, mask).clone().detach()

    with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False):
        rand_eager_upward = cpt(x, x, x, mask).clone().detach()

    t0 = benchmark.Timer(
        stmt="forw_back_fused(cpt, (x,x,x,mask), rand_fused_upward)",
        globals={
            "forw_back_fused": forw_back_fused,
            "cpt": cpt,
            "x": x,
            "rand_fused_upward": rand_fused_upward,
            "mask": mask,
        },
        label=f"Fused SDP forward and backward batch_size={batch_size} max_sequence_len={max_sequence_len} "
        f"num_heads={num_heads} embed_dimension={embed_dimension} dtype={dtype}",
        num_threads=torch.get_num_threads(),
    )

    t1 = benchmark.Timer(
        stmt="forw_back_eager(cpt, (x,x,x,mask), rand_eager_upward)",
        globals={
            "forw_back_eager": forw_back_eager,
            "cpt": cpt,
            "x": x,
            "rand_eager_upward": rand_eager_upward,
            "mask": mask,
        },
        label=f"Eager SDP forward and backward batch_size={batch_size} max_sequence_len={max_sequence_len} "
        f"num_heads={num_heads} embed_dimension={embed_dimension} dtype={dtype}",
        num_threads=torch.get_num_threads(),
    )

    m0 = t0.blocked_autorange(min_run_time=min_run_time)
    m1 = t1.blocked_autorange(min_run_time=min_run_time)
    print(m0)
    print(m1)

    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]

    print("Profile for Fused".center(200, "-"))
    with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True):
        with profile(
            activities=activities, record_shapes=False, with_stack=True
        ) as prof:
            with record_function("Fused SDP forward and backward"):
                for _ in range(20):
                    forw_back(cpt, (x, x, x, mask), rand_fused_upward)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))

    print("Profile for eager".center(200, "-"))
    with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False):
        with profile(
            activities=activities, record_shapes=False, with_stack=True
        ) as prof:
            with record_function("Fused SDP forward and backward"):
                for _ in range(20):
                    forw_back(cpt, (x, x, x, mask), rand_eager_upward)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))


def main():
    seed = 123
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)

    min_run_time = 10
    batch_size = 64
    num_heads = 32
    max_seq_len = 256
    embed_dim = 1024
    dtype = torch.bfloat16

    print(
        f"Running timing for batch_size={batch_size} max_sequence_len={max_seq_len} "
        f"num_heads={num_heads} embed_dimension={embed_dim} dtype={dtype}"
    )
    run_timing(min_run_time, batch_size, embed_dim, num_heads, max_seq_len, dtype)


if __name__ == "__main__":
    main()
```

New file (the isolated SDPA benchmark):

@@ -0,0 +1,173 @@
```python
import itertools
from collections import defaultdict
from dataclasses import asdict, dataclass
from typing import Callable, List, Tuple

import torch
import torch.utils.benchmark as benchmark
from tabulate import tabulate
from torch.nn.functional import scaled_dot_product_attention
from tqdm import tqdm


def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    # warmup
    for _ in range(5):
        func(*args, **kwargs)
    t0 = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "func": func},
    )
    return t0.adaptive_autorange(min_run_time=0.1).median * 1e6


@dataclass(frozen=True)
class ExperimentConfig:
    batch_size: int
    num_heads: int
    q_seq_len: int
    kv_seq_len: int
    embed_dim: int
    is_causal: bool
    dtype: torch.dtype
    device: torch.device = torch.device("cuda")

    @property
    def head_dim(self) -> int:
        return self.embed_dim // self.num_heads

    def asdict(self):
        dict_obj = asdict(self)
        dict_obj["head_dim"] = self.head_dim
        return dict_obj


@dataclass(frozen=True)
class ExperimentResults:
    forward_time: float
    backward_time: float

    def asdict(self):
        return asdict(self)


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    results: ExperimentResults

    def asdict(self):
        dict1 = asdict(self.config)
        dict2 = asdict(self.results)
        return {**dict1, **dict2}


def get_input(
    config: ExperimentConfig,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    q = torch.randn(
        (config.batch_size, config.num_heads, config.q_seq_len, config.head_dim),
        dtype=config.dtype,
        device=config.device,
        requires_grad=True,
    )
    k = torch.randn(
        (config.batch_size, config.num_heads, config.kv_seq_len, config.head_dim),
        dtype=config.dtype,
        device=config.device,
        requires_grad=True,
    )
    v = torch.randn(
        (config.batch_size, config.num_heads, config.kv_seq_len, config.head_dim),
        dtype=config.dtype,
        device=config.device,
        requires_grad=True,
    )
    return q, k, v


def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
    q, k, v = get_input(config)
    is_causal = config.is_causal

    forward_time = benchmark_torch_function_in_microseconds(
        scaled_dot_product_attention,
        q,
        k,
        v,
        is_causal=is_causal,
        attn_mask=None,
    )
    out_torch = scaled_dot_product_attention(
        q, k, v, is_causal=is_causal, attn_mask=None
    )
    dOut = torch.randn_like(out_torch)
    backward_time = benchmark_torch_function_in_microseconds(
        out_torch.backward, dOut, retain_graph=True
    )

    return ExperimentResults(
        forward_time=forward_time,
        backward_time=backward_time,
    )


def generate_experiment_configs() -> List[ExperimentConfig]:
    batch_sizes = [
        1,
        8,
    ]
    num_heads = [16]
    q_kv_seq_lens = [(128, 128), (256, 256), (512, 512), (1024, 1024)]
    embed_dims = [2048]
    dtypes = [
        torch.bfloat16,
    ]
    is_causal = [True, False]
    all_configs = []

    for (
        bsz,
        heads,
        (q_seq_len, kv_seq_len),
        embed_dim,
        causal,
        dtype,
    ) in itertools.product(
        batch_sizes, num_heads, q_kv_seq_lens, embed_dims, is_causal, dtypes
    ):
        all_configs.append(
            ExperimentConfig(
                batch_size=bsz,
                num_heads=heads,
                q_seq_len=q_seq_len,
                kv_seq_len=kv_seq_len,
                embed_dim=embed_dim,
                is_causal=causal,
                dtype=dtype,
            )
        )

    return all_configs


def print_results(experiments: List[Experiment]):
    table_data = defaultdict(list)
    for experiment in experiments:
        for key, value in experiment.asdict().items():
            table_data[key].append(value)
    del table_data["device"]
    print(tabulate(table_data, headers="keys", tablefmt="pretty", floatfmt=".3f"))


def main():
    seed = 123
    torch.manual_seed(seed)
    results = []
    for config in tqdm(generate_experiment_configs()):
        results.append(Experiment(config, run_single_experiment(config)))

    print_results(results)


if __name__ == "__main__":
    main()
```
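
The sweep is driven entirely by `generate_experiment_configs()`, but the helpers can also be exercised for a one-off configuration. A small sketch, meant to run inside the new module (or with its definitions imported); the specific values below are just an example and not part of this PR:

```python
# Hypothetical one-off run of a single configuration, reusing the helpers above.
config = ExperimentConfig(
    batch_size=8,
    num_heads=16,
    q_seq_len=512,
    kv_seq_len=512,
    embed_dim=2048,
    is_causal=True,
    dtype=torch.bfloat16,
)
result = run_single_experiment(config)
print(f"forward: {result.forward_time:.1f} us, backward: {result.backward_time:.1f} us")
```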