This is a follow-up to #165214, continuing to apply the ruff UP035 rule to the code base. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165515 Approved by: https://github.com/Lucaskabela
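"""Benchmark scaled_dot_product_attention with a materialized causal mask
versus the CausalBias attn-mask subclass from torch.nn.attention.bias.

For each config, a CompositeMHA module is timed twice -- once with the mask
materialized to a dense tensor and once with the CausalBias object passed
directly -- and the speedup of the subclass path is reported.
"""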
import itertools
from collections.abc import Callable
from dataclasses import asdict, dataclass
from functools import partial
from typing import Union

import numpy as np
from tabulate import tabulate
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
from torch.nn.attention.bias import CausalBias, CausalVariant
from torch.nn.parameter import Parameter


def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    # Warm up so one-time costs (CUDA context init, kernel caches) are not
    # attributed to the measured function.
    for _ in range(5):
        func(*args, **kwargs)
    t0 = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "func": func},
    )
    return t0.adaptive_autorange(min_run_time=0.1).median * 1e6

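# Example usage (a sketch, not executed by the benchmark): timing a bare
# matmul with the helper above; the shape and dtype are illustrative only.
#
#     a = torch.rand(1024, 1024, device="cuda", dtype=torch.bfloat16)
#     mm_us = benchmark_torch_function_in_microseconds(torch.mm, a, a)
#     print(f"torch.mm: {mm_us:.1f} us")
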
@dataclass(frozen=True)
class ExperimentConfig:
    batch_size: int
    num_heads: int
    q_seq_len: int
    k_seq_len: int
    embed_dim: int
    dtype: torch.dtype

    @property
    def head_dim(self) -> int:
        return self.embed_dim // self.num_heads

    def asdict(self):
        dict_obj = asdict(self)
        dict_obj["head_dim"] = self.head_dim
        return dict_obj

    def get_entries(self) -> list:
        # Row entries in field order (plus the derived head_dim); referenced
        # by Experiment.get_entries below.
        return list(self.asdict().values())


@dataclass(frozen=True)
class ExperimentResults:
    materialized_mask_time: float
    attn_mask_subclass_time: float

    def get_entries(self) -> list:
        return [
            f"{self.materialized_mask_time:.2f}",
            f"{self.attn_mask_subclass_time:.2f}",
        ]


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    results: ExperimentResults

    def get_entries(self) -> list:
        return self.config.get_entries() + self.results.get_entries()


def generate_inputs(
    batch_size, q_sequence_length, kv_sequence_length, embed_dim, dtype, device
):
    q_shape = (batch_size, q_sequence_length, embed_dim)
    kv_shape = (batch_size, kv_sequence_length, embed_dim)

    make_q = partial(torch.rand, q_shape, device=device, dtype=dtype)
    make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype)
    return make_q(), make_kv(), make_kv()


class CompositeMHA(torch.nn.Module):
    def __init__(self, num_heads, embed_dim, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.head_dim = embed_dim // num_heads
        self.embed_dim = embed_dim
        assert self.head_dim * num_heads == self.embed_dim, (
            "embed_dim must be divisible by num_heads"
        )

        self.q_proj_weight = Parameter(
            torch.empty((embed_dim, embed_dim), **factory_kwargs)
        )
        self.k_proj_weight = Parameter(
            torch.empty((embed_dim, embed_dim), **factory_kwargs)
        )
        self.v_proj_weight = Parameter(
            torch.empty((embed_dim, embed_dim), **factory_kwargs)
        )
        self.out_proj = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
        self.num_heads = num_heads

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Union[torch.Tensor, CausalBias],
    ):
        query_projected = F.linear(query, self.q_proj_weight)
        key_projected = F.linear(key, self.k_proj_weight)
        value_projected = F.linear(value, self.v_proj_weight)

        # Split embed_dim into heads and move the head dim ahead of the
        # sequence dim: (B, S, E) -> (B, num_heads, S, head_dim). Note the
        # projected tensors are reshaped here, not the raw inputs.
        query = query_projected.view(
            query_projected.size(0), -1, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key = key_projected.view(
            key_projected.size(0), -1, self.num_heads, self.head_dim
        ).transpose(1, 2)
        value = value_projected.view(
            value_projected.size(0), -1, self.num_heads, self.head_dim
        ).transpose(1, 2)

        attn = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=mask,
            dropout_p=0.0,
        )

        # (B, num_heads, S_q, head_dim) -> (B, S_q, embed_dim), then apply the
        # output projection as nn.MultiheadAttention does (only the attention
        # output is returned, not the weights).
        attn = attn.transpose(1, 2).reshape(query.size(0), -1, self.embed_dim)
        return F.linear(attn, self.out_proj)

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.q_proj_weight)
        nn.init.xavier_uniform_(self.k_proj_weight)
        nn.init.xavier_uniform_(self.v_proj_weight)
        # A zero-initialized out_proj would make every forward pass return
        # zeros, so the assert_close check in run_single_experiment would pass
        # vacuously; initialize it like the other projections instead.
        nn.init.xavier_uniform_(self.out_proj)


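# A minimal smoke-test sketch (a hypothetical helper, not part of the original
# benchmark): one forward pass on tiny shapes to illustrate the expected
# input/output contract of CompositeMHA. Shapes, dtype, and device are
# illustrative assumptions, not benchmark settings.
def _smoke_test_composite_mha(device="cuda"):
    mha = CompositeMHA(num_heads=2, embed_dim=8, device=device, dtype=torch.float32)
    mha.reset_parameters()
    q = torch.rand(1, 4, 8, device=device)  # (batch, q_seq_len, embed_dim)
    kv = torch.rand(1, 6, 8, device=device)  # (batch, kv_seq_len, embed_dim)
    mask = CausalBias(CausalVariant.LOWER_RIGHT, 4, 6)
    out = mha(q, kv, kv, mask)
    assert out.shape == (1, 4, 8)
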
def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
    device = torch.device("cuda")
    composite_mha = CompositeMHA(
        config.num_heads, config.embed_dim, device, config.dtype
    )
    composite_mha.reset_parameters()
    query, key, value = generate_inputs(
        config.batch_size,
        config.q_seq_len,
        config.k_seq_len,
        config.embed_dim,
        config.dtype,
        device,
    )
    attn_mask = CausalBias(
        CausalVariant.LOWER_RIGHT, config.q_seq_len, config.k_seq_len
    )
    # _materialize is a private API that expands the bias to a dense
    # (q_seq_len, k_seq_len) mask tensor; this is the baseline being timed.
    attn_mask_tensor = attn_mask._materialize(device)

    materialized_mask_time = benchmark_torch_function_in_microseconds(
        composite_mha, query, key, value, attn_mask_tensor
    )
    attn_mask_subclass_time = benchmark_torch_function_in_microseconds(
        composite_mha, query, key, value, attn_mask
    )
    # Both paths must agree numerically before the timings are trusted.
    torch.testing.assert_close(
        composite_mha(query, key, value, attn_mask_tensor),
        composite_mha(query, key, value, attn_mask),
    )

    return ExperimentResults(
        materialized_mask_time=materialized_mask_time,
        attn_mask_subclass_time=attn_mask_subclass_time,
    )


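# For reference (a sketch, not used by the code): CausalVariant.LOWER_RIGHT
# anchors the causal diagonal at the bottom-right of the (q_seq_len,
# k_seq_len) score matrix, so row i may attend to columns
# j <= i + (k_seq_len - q_seq_len). With q_seq_len=2, k_seq_len=4
# (True = attend):
#
#     [[True, True, True, False],
#      [True, True, True, True]]
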
def generate_experiment_configs() -> list[ExperimentConfig]:
    batch_sizes = [1, 8, 16, 128]
    num_heads = [16, 32]
    q_kv_seq_lens = [(128, 256), (256, 416), (512, 4097), (1024, 2048), (1, 2048)]
    embed_dims = [2048, 4096]
    dtypes = [
        torch.bfloat16,
    ]
    all_configs = []
    # Full cross-product: 4 * 2 * 5 * 2 * 1 = 80 configurations.
    for bsz, heads, (q_seq_len, kv_seq_len), embed_dim, dtype in itertools.product(
        batch_sizes, num_heads, q_kv_seq_lens, embed_dims, dtypes
    ):
        all_configs.append(
            ExperimentConfig(
                batch_size=bsz,
                num_heads=heads,
                q_seq_len=q_seq_len,
                k_seq_len=kv_seq_len,
                embed_dim=embed_dim,
                dtype=dtype,
            )
        )

    return all_configs


def calculate_speedup(results: ExperimentResults) -> float:
    return results.materialized_mask_time / results.attn_mask_subclass_time


def print_results(results: list[Experiment]):
    # Calculate speedups
    speedups = [calculate_speedup(r.results) for r in results]

    # Find indices of max and min speedups
    max_speedup_index = np.argmax(speedups)
    min_speedup_index = np.argmin(speedups)

    # Get the config dictionaries
    max_config_dict = results[max_speedup_index].config.asdict()
    min_config_dict = results[min_speedup_index].config.asdict()

    # Create table data; the Average row has no single config, so its config
    # columns are filled with None via dict.fromkeys.
    table_data = [
        {
            "Type": "Average",
            "Speedup": np.mean(speedups),
            **dict.fromkeys(max_config_dict),
        },
        {"Type": "Max", "Speedup": speedups[max_speedup_index], **max_config_dict},
        {"Type": "Min", "Speedup": speedups[min_speedup_index], **min_config_dict},
    ]

    # Print table
    print(tabulate(table_data, headers="keys", tablefmt="pretty"))


def main():
    seed = 123
    np.random.seed(seed)
    torch.manual_seed(seed)
    results = []
    # Time each config with a materialized dense mask vs the CausalBias
    # attn-mask subclass.
    for config in tqdm(generate_experiment_configs()):
        results.append(Experiment(config, run_single_experiment(config)))

    print_results(results)


if __name__ == "__main__":
    main()