Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Update the sdpa benchmark to measure forward backward time in isolation (#115986)
# Summary

The benchmarks were getting a little stale, and I think it makes more sense now to measure the SDPA forward and backward passes in isolation rather than end to end inside an MHA component. This is a prerequisite for gathering the data for https://github.com/pytorch/pytorch/pull/115357.

Output from a run (forward and backward times are in microseconds):

| batch_size | num_heads | q_seq_len | kv_seq_len | embed_dim | is_causal | dtype | forward_time (µs) | backward_time (µs) |
|---|---|---|---|---|---|---|---|---|
| 1 | 16 | 128 | 128 | 2048 | True | torch.bfloat16 | 23.86634959839284 | 66.21150835417211 |
| 1 | 16 | 128 | 128 | 2048 | False | torch.bfloat16 | 23.452017060481012 | 66.90612225793302 |
| 1 | 16 | 256 | 256 | 2048 | True | torch.bfloat16 | 24.478124547749758 | 76.4232068322599 |
| 1 | 16 | 256 | 256 | 2048 | False | torch.bfloat16 | 24.6928428998217 | 75.76151192188263 |
| 1 | 16 | 512 | 512 | 2048 | True | torch.bfloat16 | 28.69622849393636 | 114.73898496478796 |
| 1 | 16 | 512 | 512 | 2048 | False | torch.bfloat16 | 34.399422979913645 | 112.96746158041059 |
| 1 | 16 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 65.4690912924707 | 216.26344555988908 |
| 1 | 16 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 88.57532404363155 | 212.07790216431025 |
| 8 | 16 | 128 | 128 | 2048 | True | torch.bfloat16 | 11.582905380055308 | 70.09557797573505 |
| 8 | 16 | 128 | 128 | 2048 | False | torch.bfloat16 | 12.068384909071026 | 70.01491216942668 |
| 8 | 16 | 256 | 256 | 2048 | True | torch.bfloat16 | 31.671419646590945 | 203.54910241439939 |
| 8 | 16 | 256 | 256 | 2048 | False | torch.bfloat16 | 33.0585768679157 | 209.45609430782497 |
| 8 | 16 | 512 | 512 | 2048 | True | torch.bfloat16 | 87.43969700299202 | 469.8729298543185 |
| 8 | 16 | 512 | 512 | 2048 | False | torch.bfloat16 | 123.9265550393611 | 580.1084265112877 |
| 8 | 16 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 561.1918237991632 | 1181.655174586922 |
| 8 | 16 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 884.2707145959139 | 1662.4679416418073 |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115986
Approved by: https://github.com/mikaylagawarecki
Committed by: PyTorch MergeBot
Parent: bf62511e07
Commit: 6b120c6cf9
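The diff below removes the old end-to-end MHA benchmark and adds `benchmarks/transformer/sdpa.py`, which times the `scaled_dot_product_attention` forward call and a backward call on its output separately with `torch.utils.benchmark`. A minimal sketch of that measurement pattern follows; it is illustrative only, with the `time_us` helper and the example shape standing in for the benchmark's `benchmark_torch_function_in_microseconds` and its config sweep, and it assumes a CUDA device with bfloat16 support.

```python
import torch
import torch.utils.benchmark as benchmark
from torch.nn.functional import scaled_dot_product_attention


def time_us(func, *args, **kwargs) -> float:
    # Simplified stand-in for the benchmark's timing helper:
    # warm up, then report the median runtime in microseconds.
    for _ in range(5):
        func(*args, **kwargs)
    t = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals={"func": func, "args": args, "kwargs": kwargs},
    )
    return t.adaptive_autorange(min_run_time=0.1).median * 1e6


def make_qkv():
    # One example shape: batch=8, heads=16, seq_len=512, head_dim=128 (embed_dim 2048 / 16 heads).
    return torch.randn(
        8, 16, 512, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True
    )


q, k, v = make_qkv(), make_qkv(), make_qkv()

# Forward pass timed on its own.
forward_us = time_us(scaled_dot_product_attention, q, k, v, is_causal=True)

# Backward pass timed on its own, against a fixed upstream gradient.
out = scaled_dot_product_attention(q, k, v, is_causal=True)
grad_out = torch.randn_like(out)
backward_us = time_us(out.backward, grad_out, retain_graph=True)

print(f"forward: {forward_us:.2f} us, backward: {backward_us:.2f} us")
```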
Deleted (the previous end-to-end MHA benchmark):

@@ -1,189 +0,0 @@

```python
import random

import numpy as np
import torch
import torch.utils.benchmark as benchmark
from torch.profiler import profile, ProfilerActivity, record_function


class CompositeMHA(torch.nn.Module):
    def __init__(self, num_heads, in_proj_weight, in_proj_bias, out_proj):
        super().__init__()
        self.in_proj_weight = in_proj_weight
        self.in_proj_bias = in_proj_bias
        self.out_proj = out_proj
        self.num_heads = num_heads

    def forward(self, query, key, value, mask):
        if not (query is key and key is value):
            raise NotImplementedError(
                "query, key and value must be the same Tensor for now."
            )
        if mask is not None:
            raise NotImplementedError("mask is currently not supported.")

        query_projected = torch.nn.functional.linear(
            query, self.in_proj_weight, self.in_proj_bias
        )

        batch_size = query_projected.size(0)
        embed_dim = query_projected.size(2)
        head_dim = embed_dim // (self.num_heads * 3)

        query, key, value = query_projected.chunk(3, -1)

        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        attn = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=False,
        )

        attn = attn.transpose(1, 2).reshape(batch_size, -1, self.num_heads * head_dim)
        # Match return signature of nn.MHA
        return self.out_proj(attn)


def build_composite_mha_from_nn_mha(pt):
    assert pt._qkv_same_embed_dim
    in_proj_weight = pt.in_proj_weight
    assert in_proj_weight is not None
    assert pt.batch_first
    return CompositeMHA(pt.num_heads, pt.in_proj_weight, pt.in_proj_bias, pt.out_proj)


def forw_back(model, input, upward):
    output = model(*input)
    output.backward(upward)


# Context manager not working in timer


def forw_back_fused(model, input, upward):
    with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True):
        output = model(*input)
        output.backward(upward)


def forw_back_eager(model, input, upward):
    with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False):
        output = model(*input)
        output.backward(upward)


def run_timing(
    min_run_time, batch_size, embed_dimension, num_heads, max_sequence_len, dtype
):
    dropout_p = 0.0
    mask = None

    pt = torch.nn.MultiheadAttention(
        embed_dim=embed_dimension,
        num_heads=num_heads,
        batch_first=True,
        dropout=dropout_p,
    )
    npt = pt.cuda().to(dtype)
    cpt = build_composite_mha_from_nn_mha(npt)
    x = torch.randn(
        batch_size,
        max_sequence_len,
        embed_dimension,
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True):
        rand_fused_upward = cpt(x, x, x, mask).clone().detach()

    with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False):
        rand_eager_upward = cpt(x, x, x, mask).clone().detach()

    t0 = benchmark.Timer(
        stmt="forw_back_fused(cpt, (x,x,x,mask), rand_fused_upward)",
        globals={
            "forw_back_fused": forw_back_fused,
            "cpt": cpt,
            "x": x,
            "rand_fused_upward": rand_fused_upward,
            "mask": mask,
        },
        label=f"Fused SDP forward and backward batch_size={batch_size} max_sequence_len={max_sequence_len} "
        f"num_heads={num_heads} embed_dimension={embed_dimension} dtype={dtype}",
        num_threads=torch.get_num_threads(),
    )

    t1 = benchmark.Timer(
        stmt="forw_back_eager(cpt, (x,x,x,mask), rand_eager_upward)",
        globals={
            "forw_back_eager": forw_back_eager,
            "cpt": cpt,
            "x": x,
            "rand_eager_upward": rand_eager_upward,
            "mask": mask,
        },
        label=f"Eager SDP forward and backward batch_size={batch_size} max_sequence_len={max_sequence_len} "
        f"num_heads={num_heads} embed_dimension={embed_dimension} dtype={dtype}",
        num_threads=torch.get_num_threads(),
    )

    m0 = t0.blocked_autorange(min_run_time=min_run_time)
    m1 = t1.blocked_autorange(min_run_time=min_run_time)

    print(m0)
    print(m1)

    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]

    print("Profile for Fused".center(200, "-"))
    with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True):
        with profile(
            activities=activities, record_shapes=False, with_stack=True
        ) as prof:
            with record_function("Fused SDP forward and backward"):
                for _ in range(20):
                    forw_back(cpt, (x, x, x, mask), rand_fused_upward)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))

    print("Profile for eager".center(200, "-"))
    with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False):
        with profile(
            activities=activities, record_shapes=False, with_stack=True
        ) as prof:
            with record_function("Fused SDP forward and backward"):
                for _ in range(20):
                    forw_back(cpt, (x, x, x, mask), rand_eager_upward)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))


def main():
    seed = 123
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)

    min_run_time = 10
    batch_size = 64
    num_heads = 32
    max_seq_len = 256
    embed_dim = 1024
    dtype = torch.bfloat16

    print(
        f"Running timing for batch_size={batch_size} max_sequence_len={max_seq_len} "
        f"num_heads={num_heads} embed_dimension={embed_dim} dtype={dtype}"
    )
    run_timing(min_run_time, batch_size, embed_dim, num_heads, max_seq_len, dtype)


if __name__ == "__main__":
    main()
```
Added: benchmarks/transformer/sdpa.py (new file, 173 lines)
@@ -0,0 +1,173 @@

```python
import itertools
from collections import defaultdict
from dataclasses import asdict, dataclass
from typing import Callable, List, Tuple

import torch
import torch.utils.benchmark as benchmark
from tabulate import tabulate
from torch.nn.functional import scaled_dot_product_attention
from tqdm import tqdm


def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    # warmup
    for _ in range(5):
        func(*args, **kwargs)
    t0 = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "func": func},
    )
    return t0.adaptive_autorange(min_run_time=0.1).median * 1e6


@dataclass(frozen=True)
class ExperimentConfig:
    batch_size: int
    num_heads: int
    q_seq_len: int
    kv_seq_len: int
    embed_dim: int
    is_causal: bool
    dtype: torch.dtype
    device: torch.device = torch.device("cuda")

    @property
    def head_dim(self) -> int:
        return self.embed_dim // self.num_heads

    def asdict(self):
        dict_obj = asdict(self)
        dict_obj["head_dim"] = self.head_dim
        return dict_obj


@dataclass(frozen=True)
class ExperimentResults:
    forward_time: float
    backward_time: float

    def asdict(self):
        return asdict(self)


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    results: ExperimentResults

    def asdict(self):
        dict1 = asdict(self.config)
        dict2 = asdict(self.results)
        return {**dict1, **dict2}


def get_input(
    config: ExperimentConfig,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    q = torch.randn(
        (config.batch_size, config.num_heads, config.q_seq_len, config.head_dim),
        dtype=config.dtype,
        device=config.device,
        requires_grad=True,
    )
    k = torch.randn(
        (config.batch_size, config.num_heads, config.kv_seq_len, config.head_dim),
        dtype=config.dtype,
        device=config.device,
        requires_grad=True,
    )
    v = torch.randn(
        (config.batch_size, config.num_heads, config.kv_seq_len, config.head_dim),
        dtype=config.dtype,
        device=config.device,
        requires_grad=True,
    )
    return q, k, v


def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
    q, k, v = get_input(config)
    is_causal = config.is_causal

    forward_time = benchmark_torch_function_in_microseconds(
        scaled_dot_product_attention,
        q,
        k,
        v,
        is_causal=is_causal,
        attn_mask=None,
    )
    out_torch = scaled_dot_product_attention(
        q, k, v, is_causal=is_causal, attn_mask=None
    )
    dOut = torch.randn_like(out_torch)
    backward_time = benchmark_torch_function_in_microseconds(
        out_torch.backward, dOut, retain_graph=True
    )

    return ExperimentResults(
        forward_time=forward_time,
        backward_time=backward_time,
    )


def generate_experiment_configs() -> List[ExperimentConfig]:
    batch_sizes = [
        1,
        8,
    ]
    num_heads = [16]
    q_kv_seq_lens = [(128, 128), (256, 256), (512, 512), (1024, 1024)]
    embed_dims = [2048]
    dtypes = [
        torch.bfloat16,
    ]
    is_causal = [True, False]
    all_configs = []
    for (
        bsz,
        heads,
        (q_seq_len, kv_seq_len),
        embed_dim,
        causal,
        dtype,
    ) in itertools.product(
        batch_sizes, num_heads, q_kv_seq_lens, embed_dims, is_causal, dtypes
    ):
        all_configs.append(
            ExperimentConfig(
                batch_size=bsz,
                num_heads=heads,
                q_seq_len=q_seq_len,
                kv_seq_len=kv_seq_len,
                embed_dim=embed_dim,
                is_causal=causal,
                dtype=dtype,
            )
        )

    return all_configs


def print_results(experiments: List[Experiment]):
    table_data = defaultdict(list)
    for experiment in experiments:
        for key, value in experiment.asdict().items():
            table_data[key].append(value)
    del table_data["device"]
    print(tabulate(table_data, headers="keys", tablefmt="pretty", floatfmt=".3f"))


def main():
    seed = 123
    torch.manual_seed(seed)
    results = []
    for config in tqdm(generate_experiment_configs()):
        results.append(Experiment(config, run_single_experiment(config)))

    print_results(results)


if __name__ == "__main__":
    main()
```
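To reproduce a table like the one above, the new script can be run directly (it expects a CUDA device and the `tabulate` and `tqdm` packages): `python benchmarks/transformer/sdpa.py`.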