See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter. You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129754
Approved by: https://github.com/ezyang

363 lines
11 KiB
Python

import argparse
import itertools
import random
import warnings
from dataclasses import dataclass
from pathlib import Path
from pprint import pprint
from typing import List, Optional

import numpy as np
from prettytable import PrettyTable
from tqdm import tqdm

import torch
import torch.utils.benchmark as benchmark
from torch.backends.cuda import sdp_kernel


warnings.filterwarnings("ignore")

@dataclass(frozen=True)
class ExperimentConfig:
    batch_size: int
    num_heads: int
    max_sequence_len: int
    embed_dimension: int
    dtype: torch.dtype
    pad_percentage: Optional[float]
    enable_math: bool
    enable_flash: bool
    enable_mem_efficient: bool
    enable_cudnn: bool

    def get_entries(self) -> List:
        return [
            self.batch_size,
            self.num_heads,
            self.max_sequence_len,
            self.embed_dimension,
            self.dtype,
            self.pad_percentage,
            self.enable_math,
            self.enable_flash,
            self.enable_mem_efficient,
            self.enable_cudnn,
        ]

    @classmethod
    def get_entry_names(cls) -> List[str]:
        return [
            "batch_size",
            "num_heads",
            "max_sequence_len",
            "embed_dimension",
            "dtype",
            "pad_percentage",
            "enable_math",
            "enable_flash",
            "enable_mem_efficient",
            "enable_cudnn",
        ]

@dataclass(frozen=True)
class ExperimentResults:
    nn_mha_time: float
    compiled_nn_mha_time: Optional[float]
    composite_mha_time: float
    compiled_composite_mha_time: Optional[float]

    def get_entries(self) -> List:
        return [
            f"{self.nn_mha_time:.2f}",
            f"{self.compiled_nn_mha_time:.2f}" if self.compiled_nn_mha_time else None,
            f"{self.composite_mha_time:.2f}",
            f"{self.compiled_composite_mha_time:.2f}"
            if self.compiled_composite_mha_time
            else None,
        ]

    @classmethod
    def get_entry_names(cls) -> List[str]:
        return [
            "nn_mha_time (\u00B5s)",
            "compiled_nn_mha_time (\u00B5s)",
            "composite_mha_time (\u00B5s)",
            "compiled_composite_mha_time (\u00B5s)",
        ]

@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    results: ExperimentResults

    def get_entries(self) -> List:
        return self.config.get_entries() + self.results.get_entries()

class CompositeMHA(torch.nn.Module):
    def __init__(self, num_heads, in_proj_weight, in_proj_bias, out_proj):
        super().__init__()
        self.in_proj_weight = in_proj_weight
        self.in_proj_bias = in_proj_bias
        self.out_proj = out_proj
        self.num_heads = num_heads

    def forward(self, query, key, value, mask):
        if not (query is key and key is value):
            raise NotImplementedError(
                "query, key and value must be the same Tensor for now."
            )
        if mask is not None:
            raise NotImplementedError("mask is currently not supported.")

        query_projected = torch.nn.functional.linear(
            query, self.in_proj_weight, self.in_proj_bias
        )

        batch_size = query_projected.size(0)
        # The in-projection packs q, k, and v along the last dimension, so
        # query_projected's last dimension is 3x the model's embed dim.
        embed_dim = query_projected.size(2)
        head_dim = embed_dim // (self.num_heads * 3)

        query, key, value = query_projected.chunk(3, -1)

        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        attn = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=False,
        )

        attn = attn.transpose(1, 2).reshape(batch_size, -1, self.num_heads * head_dim)
        # Match return signature of nn.MHA
        return self.out_proj(attn), None

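# A minimal shape sanity check for CompositeMHA (illustrative only; the
# randomly initialized projection weights below are assumptions, not weights
# taken from a trained module):
#
#   embed_dim, num_heads = 512, 8
#   w_in = torch.randn(3 * embed_dim, embed_dim, device="cuda", dtype=torch.float16)
#   b_in = torch.randn(3 * embed_dim, device="cuda", dtype=torch.float16)
#   out_proj = torch.nn.Linear(embed_dim, embed_dim, device="cuda", dtype=torch.float16)
#   mha = CompositeMHA(num_heads, w_in, b_in, out_proj)
#   x = torch.randn(2, 128, embed_dim, device="cuda", dtype=torch.float16)
#   out, _ = mha(x, x, x, mask=None)  # out.shape == (2, 128, embed_dim)
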
def build_composite_mha_from_nn_mha(pt):
    assert pt._qkv_same_embed_dim
    in_proj_weight = pt.in_proj_weight
    assert in_proj_weight is not None
    assert pt.batch_first
    return CompositeMHA(pt.num_heads, pt.in_proj_weight, pt.in_proj_bias, pt.out_proj)

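# Usage sketch (mirrors what run_single_experiment does below; the argument
# values here are arbitrary examples):
#
#   pt = torch.nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
#   composite_mha = build_composite_mha_from_nn_mha(pt)
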
def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
    if not pad_percentage:
        return (
            torch.randn(
                batch_size,
                max_sequence_len,
                embed_dimension,
                dtype=dtype,
                device=device,
            ),
            None,
        )
    # Really slow but should work
    seq_len_list = [
        int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        for _ in range(batch_size)
    ]
    # Make one random element the max length
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
    # print(f"Theoretical padding: {pad_percentage} actual: {1 - (sum(seq_len_list) / (batch_size * max_sequence_len))}")
    return (
        torch.nested.nested_tensor(
            [
                torch.randn(seq_len, embed_dimension, dtype=dtype, device=device)
                for seq_len in seq_len_list
            ]
        ),
        seq_len_list,
    )

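# Example with hypothetical numbers: pad_percentage=0.5 draws per-sequence
# lengths from a Gaussian centered at 50% of max_sequence_len, then forces
# one entry to full length:
#
#   qkv, lengths = generate_rand_batch(4, 512, 256, pad_percentage=0.5)
#   # lengths might be e.g. [251, 263, 512, 249]; qkv.is_nested is True
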
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    # blocked_autorange() reports the mean time in seconds; convert to microseconds.
    return t0.blocked_autorange().mean * 1e6

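# A standalone sketch of the same Timer pattern (torch.matmul is just an
# arbitrary stand-in workload, not something this script benchmarks):
#
#   a, b = torch.randn(64, 64), torch.randn(64, 64)
#   t = benchmark.Timer(stmt="torch.matmul(a, b)", globals={"a": a, "b": b, "torch": torch})
#   print(f"{t.blocked_autorange().mean * 1e6:.2f} us")
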
def assert_close_tensors(tensor_a, tensor_b):
    # First order sanity check. Not a replacement for rigorous tests.
    if tensor_a.is_nested and tensor_b.is_nested:
        for a, b in zip(tensor_a.unbind(), tensor_b.unbind()):
            assert torch.allclose(a, b, atol=1e-2, rtol=1e-2)
    else:
        assert torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)

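# sdp_kernel (used below) is a context manager that restricts which fused
# backends scaled_dot_product_attention may dispatch to. An illustrative
# standalone use, assuming half-precision CUDA tensors q, k, v:
#
#   with sdp_kernel(enable_math=False, enable_flash=True,
#                   enable_mem_efficient=False, enable_cudnn=False):
#       out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
#   # Raises a RuntimeError if FlashAttention cannot handle the inputs,
#   # since every fallback backend is disabled.
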
def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
    with sdp_kernel(
        enable_math=config.enable_math,
        enable_flash=config.enable_flash,
        enable_mem_efficient=config.enable_mem_efficient,
        enable_cudnn=config.enable_cudnn,
    ), torch.inference_mode():
        dropout_p = 0.0
        mask = None

        nn_mha = torch.nn.MultiheadAttention(
            embed_dim=config.embed_dimension,
            num_heads=config.num_heads,
            batch_first=True,
            dropout=dropout_p,
        )
        nn_mha = nn_mha.eval().to("cuda", config.dtype)
        composite_mha = build_composite_mha_from_nn_mha(nn_mha)
        qkv, lengths = generate_rand_batch(
            config.batch_size,
            config.max_sequence_len,
            config.embed_dimension,
            config.pad_percentage,
            config.dtype,
        )
        nn_mha_output, _ = nn_mha(qkv, qkv, qkv, mask)
        composite_mha_output, _ = composite_mha(qkv, qkv, qkv, mask)

        # First order sanity check
        assert_close_tensors(nn_mha_output, composite_mha_output)

        nn_mha_time = benchmark_torch_function_in_microseconds(
            nn_mha, qkv, qkv, qkv, mask
        )
        composite_mha_time = benchmark_torch_function_in_microseconds(
            composite_mha, qkv, qkv, qkv, mask
        )

        # TorchDynamo will error on NestedTensors
        if config.pad_percentage is None:
            compiled_nn_mha = torch.compile(nn_mha)
            compiled_composite_mha = torch.compile(composite_mha)

            compiled_nn_mha_time = benchmark_torch_function_in_microseconds(
                compiled_nn_mha, qkv, qkv, qkv, mask
            )

            compiled_composite_mha_time = benchmark_torch_function_in_microseconds(
                compiled_composite_mha,
                qkv,
                qkv,
                qkv,
                mask,
            )
        else:
            compiled_nn_mha_time = None
            compiled_composite_mha_time = None

        results = ExperimentResults(
            nn_mha_time,
            compiled_nn_mha_time,
            composite_mha_time,
            compiled_composite_mha_time,
        )
        return Experiment(config, results)

# Could return generator
def generate_experiments(
    batch_sizes, num_heads, max_seq_lens, embed_dims, dtypes, pad_percentages
) -> List[ExperimentConfig]:
    configs = []
    for bsz, n_heads, seq_len, embed_dim, dtype, padding in itertools.product(
        batch_sizes, num_heads, max_seq_lens, embed_dims, dtypes, pad_percentages
    ):
        configs.append(
            ExperimentConfig(
                batch_size=bsz,
                num_heads=n_heads,
                max_sequence_len=seq_len,
                embed_dimension=embed_dim,
                dtype=dtype,
                pad_percentage=padding,
                enable_math=False,
                enable_flash=True,
                enable_mem_efficient=True,
                enable_cudnn=True,
            )
        )
    return configs

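# The sweep size is the product of the list lengths. With the lists in main()
# below (one value each for batch size, heads, sequence length, and embed dim;
# 3 dtypes; 2 pad percentages), itertools.product yields 3 * 2 = 6 configs.
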
def main(save_path: Optional[Path]):
    seed = 123
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Run one timing experiment comparing nn_mha vs composite_mha
    config = ExperimentConfig(
        batch_size=128,
        num_heads=8,
        max_sequence_len=512,
        embed_dimension=512,
        dtype=torch.float16,
        pad_percentage=None,
        enable_math=False,
        enable_flash=True,
        enable_mem_efficient=True,
        enable_cudnn=True,
    )

    experiment = run_single_experiment(config)
    pprint(experiment)

    table = PrettyTable()
    table.float_format = ".3"
    table.field_names = (
        ExperimentConfig.get_entry_names() + ExperimentResults.get_entry_names()
    )

    # Run a bunch of experiments
    batch_sizes = [256]
    num_heads = [32]
    max_seq_lens = [256]
    embed_dims = [512]
    dtypes = [torch.bfloat16, torch.float16, torch.float32]
    pad_percentages = [None, 0.9]

    experiment_configs = generate_experiments(
        batch_sizes, num_heads, max_seq_lens, embed_dims, dtypes, pad_percentages
    )

    experiments: List[Experiment] = []
    for experiment_config in tqdm(experiment_configs):
        experiment = run_single_experiment(experiment_config)
        experiments.append(experiment)
        table.add_row(experiment.get_entries())

    print(table)

    csv_string = table.get_csv_string()
    if save_path is not None:
        with open(save_path, "w") as csvfile:
            csvfile.write(csv_string)

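# Example invocation (the script name here is a placeholder for this file's
# actual name; the output path is arbitrary):
#
#   python sdp.py --save-path results.csv
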
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save-path", "--save_path", type=str, help="Path to save the results"
    )

    args = parser.parse_args()
    save_path = Path(args.save_path) if args.save_path else None
    main(save_path)