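"""Benchmark forward and backward speed of scaled_dot_product_attention.

Sweeps batch size, sequence length, causal masking, and dtype (and optionally a
single SDPA backend) and prints the median forward/backward runtimes in
microseconds as a table.

Usage (the file name here is assumed; run the script directly, e.g.):
    python sdpa.py
"""
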
import itertools

from collections import defaultdict
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from typing import Callable, List, Tuple

import torch
import torch.utils.benchmark as benchmark
from tabulate import tabulate
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.functional import scaled_dot_product_attention
from tqdm import tqdm


def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    # warmup
    for _ in range(5):
        func(*args, **kwargs)
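    # torch.utils.benchmark.Timer handles CUDA synchronization when measuring,
    # so GPU kernels are timed correctly here.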
    t0 = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "func": func},
    )
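    # adaptive_autorange reruns the statement until the measurement is stable;
    # median is reported in seconds, so scale to microseconds.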
    return t0.adaptive_autorange(min_run_time=0.1).median * 1e6


@dataclass(frozen=True)
class ExperimentConfig:
    batch_size: int
    num_heads: int
    q_seq_len: int
    kv_seq_len: int
    embed_dim: int
    is_causal: bool
    dtype: torch.dtype
    backend: SDPBackend
    device: torch.device = torch.device("cuda")

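    # Derived per-head width; assumes embed_dim divides evenly by num_heads.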
    @property
    def head_dim(self) -> int:
        return self.embed_dim // self.num_heads

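    # dataclasses.asdict only sees declared fields, so the derived head_dim
    # property has to be appended by hand.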
    def asdict(self):
        dict_obj = asdict(self)
        dict_obj["head_dim"] = self.head_dim
        return dict_obj


@dataclass(frozen=True)
class ExperimentResults:
    forward_time: float
    backward_time: float

    def asdict(self):
        return asdict(self)


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    results: ExperimentResults

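    # Note: this applies dataclasses.asdict to the config rather than calling
    # ExperimentConfig.asdict, so the derived head_dim is not merged in.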
    def asdict(self):
        dict1 = asdict(self.config)
        dict2 = asdict(self.results)
        return {**dict1, **dict2}


def get_input(
    config: ExperimentConfig,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
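    # All tensors are (batch, heads, seq_len, head_dim) and require grad so
    # the backward pass can be benchmarked too.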
    q = torch.randn(
        (config.batch_size, config.num_heads, config.q_seq_len, config.head_dim),
        dtype=config.dtype,
        device=config.device,
        requires_grad=True,
    )
    k = torch.randn(
        (config.batch_size, config.num_heads, config.kv_seq_len, config.head_dim),
        dtype=config.dtype,
        device=config.device,
        requires_grad=True,
    )
    v = torch.randn(
        (config.batch_size, config.num_heads, config.kv_seq_len, config.head_dim),
        dtype=config.dtype,
        device=config.device,
        requires_grad=True,
    )
    return q, k, v


def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
    q, k, v = get_input(config)
    is_causal = config.is_causal
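    # sdpa_kernel() restricts scaled_dot_product_attention to the requested
    # backend; backend=None leaves every available backend enabled.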
    context = (
        sdpa_kernel(config.backend) if config.backend is not None else nullcontext()
    )
    with context:
        forward_time = benchmark_torch_function_in_microseconds(
            scaled_dot_product_attention,
            q,
            k,
            v,
            is_causal=is_causal,
            attn_mask=None,
        )
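        # Build the autograd graph once outside the timer, then time backward
        # alone; retain_graph=True lets the timer replay the backward pass.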
        out_torch = scaled_dot_product_attention(
            q, k, v, is_causal=is_causal, attn_mask=None
        )
        dOut = torch.randn_like(out_torch)
        backward_time = benchmark_torch_function_in_microseconds(
            out_torch.backward, dOut, retain_graph=True
        )

    return ExperimentResults(
        forward_time=forward_time,
        backward_time=backward_time,
    )


def generate_experiment_configs() -> List[ExperimentConfig]:
    batch_sizes = [
        1,
        8,
    ]
    num_heads = [16]
    q_kv_seq_lens = [(128, 128), (256, 256), (512, 512), (1024, 1024)]
    embed_dims = [2048]
    backends = [None]  # If set to None, all backends are enabled
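    # e.g. backends = [SDPBackend.EFFICIENT_ATTENTION] would pin the whole
    # sweep to a single implementation.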
    dtypes = [
        torch.bfloat16,
    ]
    is_causal = [True, False]
    all_configs = []
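    # Sweep the full Cartesian product; the unpacking order below must match
    # the argument order given to itertools.product.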
    for (
        bsz,
        heads,
        (q_seq_len, kv_seq_len),
        embed_dim,
        causal,
        dtype,
        backend,
    ) in itertools.product(
        batch_sizes, num_heads, q_kv_seq_lens, embed_dims, is_causal, dtypes, backends
    ):
        all_configs.append(
            ExperimentConfig(
                batch_size=bsz,
                num_heads=heads,
                q_seq_len=q_seq_len,
                kv_seq_len=kv_seq_len,
                embed_dim=embed_dim,
                is_causal=causal,
                dtype=dtype,
                backend=backend,
            )
        )

    return all_configs


def print_results(experiments: List[Experiment]):
    table_data = defaultdict(list)
    for experiment in experiments:
        for key, value in experiment.asdict().items():
            table_data[key].append(value)
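    # Drop columns that are constant for the whole run: device always, and
    # backend only when no specific backend was requested.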
    del table_data["device"]
    if table_data["backend"][0] is None:
        del table_data["backend"]
    print(tabulate(table_data, headers="keys", tablefmt="pretty", floatfmt=".3f"))


def main():
    seed = 123
    torch.manual_seed(seed)
    results = []
    for config in tqdm(generate_experiment_configs()):
        results.append(Experiment(config, run_single_experiment(config)))

    print_results(results)


if __name__ == "__main__":
    main()