ScoreMod API (#121845)
# Summary

This PR adds a new higher-order op: `templated_attention`. This op is designed to extend the functionality of `torch.nn.functional.scaled_dot_product_attention`. PyTorch ships efficient pre-written fused attention kernels. However, users often want to modify how scores are computed (a substep inside attention), and that traditionally requires writing a custom attention kernel. One such modification to attention scores that is not currently supported by the top-level SDPA op is [Attention with Linear Biases (ALiBi)](https://arxiv.org/abs/2108.12409). This higher-order op instead accepts a callable (`score_mod`) which, through torch.compile, is used to instantiate an efficient attention kernel.

### Details

This HOP uses the existing fx and HOP infrastructure to capture the user's `score_mod` function and convert it into an FX graph module. Inductor then consumes this HOP, which carries an `ir.Subgraph` input, and inlines the lowered subgraph into a Triton kernel that performs fused attention with the modification to the scores matrix applied inline.

### API

A `score_mod` function should have the following signature:

```Python
def score_mod(
    score: torch.Tensor,
    batch: torch.Tensor,
    head: torch.Tensor,
    token_q: torch.Tensor,
    token_kv: torch.Tensor,
) -> torch.Tensor
```

This function receives five parameters:

- `score`: A scalar tensor representing the attention score, with the same data type and device as the query, key, and value tensors.
- `batch`, `head`, `token_q`, `token_kv`: Scalar tensors indicating the batch index, head index, query index, and key/value index, respectively, with torch.int data type and located on the same device as the score tensor.

Consider inputs query, key, value of shape (2, 4, 16, 8), leading to an intermediate attention score matrix of shape (2, 4, 16, 16). The score_mod function will be vectorized over each element of this matrix. For instance, modifying the score at the position corresponding to the 0th batch, 2nd head, 8th query element, and 9th key element would be invoked as:

```Python
score_mod(score[0, 2, 8, 9], torch.tensor(0), torch.tensor(2), torch.tensor(8), torch.tensor(9))
```

### Examples

```Python
import torch
from torch.nn.attention.templated_attention import templated_attention

torch.manual_seed(0)

# Let's create some input tensors
# The input tensor has shape (batch_size, num_heads, seq_len, head_dim)
query = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)
key = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)
value = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)

# Let's create a fun new score modification! I will call this Checkerboard.
# It will reduce the score for neighboring tokens (1 step apart) in the sequence,
# and increase the score for tokens 2 steps apart. For everything else,
# the score will remain the same.
def checkerboard(score, batch, head, token_q, token_kv):
    score = torch.where(torch.abs(token_kv - token_q) == 1, score * 0.5, score)
    score = torch.where(torch.abs(token_kv - token_q) == 2, score * 2.0, score)
    return score

# Let's call templated_attention with this new score modification
output = templated_attention(query, key, value, score_mod=checkerboard)

compiled_templated_attention = torch.compile(templated_attention)
out_compiled = compiled_templated_attention(query, key, value, score_mod=checkerboard)

torch.testing.assert_close(output, out_compiled, atol=2e-2, rtol=2e-2)
```
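Since ALiBi is the motivating case mentioned above, here is a minimal sketch (not from the PR) of how a distance-based penalty in the spirit of ALiBi could be expressed as a `score_mod` under this API. The per-head slope schedule (`exp2(-(head + 1))`) is an illustrative choice rather than anything this PR prescribes, and it deliberately avoids capturing external buffers given the limitation noted under Future Work below. It is shown through the eager path; whether every op inside a `score_mod` lowers through the compiled path at this early stage may vary.

```Python
import torch
from torch.nn.attention.templated_attention import templated_attention

def alibi(score, batch, head, token_q, token_kv):
    # Hypothetical per-head slope derived from the head index
    # (a simple geometric schedule; the exact schedule is a modeling choice).
    slope = torch.exp2(-(head + 1).to(torch.float32))
    # Penalize attention scores in proportion to the query/key distance.
    return score - slope * torch.abs(token_q - token_kv)

query = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float32)
key = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float32)
value = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float32)

out = templated_attention(query, key, value, score_mod=alibi)
```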
### Future Work

- This PR currently supports the forward pass only. However, a Triton kernel for the backward pass, for score modifications that do not rely on external buffers, has been explored here: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/flash/flash_attention.py
- Kernel improvements: there have been some larger updates to the fused attention implementation that Triton uses in its tutorials. This kernel is based on a prior version and should be updated.
- We may want to unify this API under the top-level SDPA API; we leave that as a follow-up once this is more stable.
- Should we error on CPU?
- There are some issues with dynamic shapes.
- Capturing free variables and lifting them to inputs of the subgraph is not working correctly today.

### Performance Comparisons

The numbers below were generated by the benchmark added in this PR (`benchmarks/transformer/score_mod.py`, included below):

| Type | Speedup | batch_size | num_heads | q_seq_len | k_seq_len | head_dim | score_mod | dtype |
|---------|---------|------------|-----------|-----------|-----------|----------|---------------|----------------|
| Average | 5.412 | | | | | | | |
| Max | 8.882 | 16 | 16 | 4096 | 4096 | 64 | relative_bias | torch.bfloat16 |
| Min | 3.645 | 8 | 16 | 512 | 512 | 64 | causal_mask | torch.bfloat16 |
| Min | 0.345 | 1 | 16 | 1024 | 1024 | 64 | pathological | torch.bfloat16 |

For reference:

| Configuration | Forward Time (µs) | Backend | Speedup |
|---------------|-------------------|---------|---------|
| Fastest Config in Sweep (`8 16 4096 4096 64 relative_bias torch.bfloat16`) | 3608 | Templated Attention | 1.0 |
| Compiled SDPA (No Mask) | 9928 | Math | 2.75x |
| Compiled SDPA (With Mask) | 11898 | Math | 3.29x |
| Compiled SDPA (With Mask) | 8704 | Memory Efficient Attention | 2.42x |
| Compiled SDPA (No Mask) | 2548 | FlashAttention2 | 0.706x |

The speedups measure compiled templated attention against different calls to torch.nn.functional.scaled_dot_product_attention: each row's speedup is that backend's forward time divided by the 3608 µs of the templated attention kernel (for example, 9928 / 3608 ≈ 2.75x).
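For context on the reference table, a minimal sketch of the kind of compiled SDPA baseline being compared against is shown below. The PR does not spell out how the mask was constructed or how backends were selected for those numbers, so the shapes and the broadcastable additive mask here are illustrative assumptions only.

```Python
import torch
import torch.nn.functional as F

# Illustrative shapes matching the fastest sweep config (8, 16, 4096, 64), bf16 on CUDA.
q = torch.randn(8, 16, 4096, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(8, 16, 4096, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(8, 16, 4096, 64, device="cuda", dtype=torch.bfloat16)

compiled_sdpa = torch.compile(F.scaled_dot_product_attention)

# "No Mask" style baseline.
out_no_mask = compiled_sdpa(q, k, v)

# "With Mask" style baseline: an additive mask broadcast over batch and heads
# (how the mask was built for the table above is an assumption, not stated in the PR).
mask = torch.zeros(1, 1, 4096, 4096, device="cuda", dtype=torch.bfloat16)
out_masked = compiled_sdpa(q, k, v, attn_mask=mask)
```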
<details>
<summary> FULL PERFORMANCE SWEEP NUMBERS </summary>

| batch_size | num_heads | q_seq_len | k_seq_len | head_dim | score_mod | dtype | eager_time | compiled_time | speedup |
|------------|-----------|-----------|-----------|----------|---------------|----------------|------------|---------------|---------|
| 1 | 16 | 512 | 512 | 64 | causal_mask | torch.bfloat16 | 331.444 | 67.221 | 4.931 |
| 1 | 16 | 512 | 512 | 64 | relative_bias | torch.bfloat16 | 335.300 | 64.187 | 5.224 |
| 1 | 16 | 512 | 512 | 64 | head_bias | torch.bfloat16 | 352.039 | 63.806 | 5.517 |
| 1 | 16 | 512 | 512 | 64 | pathological | torch.bfloat16 | 371.699 | 711.349 | 0.523 |
| 1 | 16 | 1024 | 1024 | 64 | causal_mask | torch.bfloat16 | 333.488 | 86.455 | 3.857 |
| 1 | 16 | 1024 | 1024 | 64 | relative_bias | torch.bfloat16 | 322.363 | 82.469 | 3.909 |
| 1 | 16 | 1024 | 1024 | 64 | head_bias | torch.bfloat16 | 349.967 | 82.233 | 4.256 |
| 1 | 16 | 1024 | 1024 | 64 | pathological | torch.bfloat16 | 486.359 | 1412.453 | 0.344 |
| 1 | 16 | 4096 | 4096 | 64 | causal_mask | torch.bfloat16 | 2794.597 | 551.188 | 5.070 |
| 1 | 16 | 4096 | 4096 | 64 | relative_bias | torch.bfloat16 | 3965.150 | 513.101 | 7.728 |
| 1 | 16 | 4096 | 4096 | 64 | head_bias | torch.bfloat16 | 2408.013 | 504.759 | 4.771 |
| 1 | 16 | 4096 | 4096 | 64 | pathological | torch.bfloat16 | 6850.531 | 16733.675 | 0.409 |
| 8 | 16 | 512 | 512 | 64 | causal_mask | torch.bfloat16 | 441.939 | 123.576 | 3.576 |
| 8 | 16 | 512 | 512 | 64 | relative_bias | torch.bfloat16 | 560.379 | 116.710 | 4.801 |
| 8 | 16 | 512 | 512 | 64 | head_bias | torch.bfloat16 | 421.172 | 115.825 | 3.636 |
| 8 | 16 | 512 | 512 | 64 | pathological | torch.bfloat16 | 994.492 | 2132.806 | 0.466 |
| 8 | 16 | 1024 | 1024 | 64 | causal_mask | torch.bfloat16 | 1436.430 | 309.495 | 4.641 |
| 8 | 16 | 1024 | 1024 | 64 | relative_bias | torch.bfloat16 | 1892.216 | 290.186 | 6.521 |
| 8 | 16 | 1024 | 1024 | 64 | head_bias | torch.bfloat16 | 1360.665 | 282.956 | 4.809 |
| 8 | 16 | 1024 | 1024 | 64 | pathological | torch.bfloat16 | 3525.532 | 8359.702 | 0.422 |
| 8 | 16 | 4096 | 4096 | 64 | causal_mask | torch.bfloat16 | 22026.839 | 3864.604 | 5.700 |
| 8 | 16 | 4096 | 4096 | 64 | relative_bias | torch.bfloat16 | 31262.746 | 3609.551 | 8.661 |
| 8 | 16 | 4096 | 4096 | 64 | head_bias | torch.bfloat16 | 20219.079 | 3480.402 | 5.809 |
| 8 | 16 | 4096 | 4096 | 64 | pathological | torch.bfloat16 | 54654.647 | 116652.357 | 0.469 |
| 16 | 16 | 512 | 512 | 64 | causal_mask | torch.bfloat16 | 820.606 | 188.683 | 4.349 |
| 16 | 16 | 512 | 512 | 64 | relative_bias | torch.bfloat16 | 1058.362 | 179.295 | 5.903 |
| 16 | 16 | 512 | 512 | 64 | head_bias | torch.bfloat16 | 784.372 | 175.714 | 4.464 |
| 16 | 16 | 512 | 512 | 64 | pathological | torch.bfloat16 | 1890.792 | 4212.877 | 0.449 |
| 16 | 16 | 1024 | 1024 | 64 | causal_mask | torch.bfloat16 | 2781.830 | 557.017 | 4.994 |
| 16 | 16 | 1024 | 1024 | 64 | relative_bias | torch.bfloat16 | 3694.050 | 525.249 | 7.033 |
| 16 | 16 | 1024 | 1024 | 64 | head_bias | torch.bfloat16 | 2634.164 | 507.613 | 5.189 |
| 16 | 16 | 1024 | 1024 | 64 | pathological | torch.bfloat16 | 6959.917 | 15331.116 | 0.454 |
| 16 | 16 | 4096 | 4096 | 64 | causal_mask | torch.bfloat16 | 43889.096 | 7582.018 | 5.789 |
| 16 | 16 | 4096 | 4096 | 64 | relative_bias | torch.bfloat16 | 62784.293 | 7075.846 | 8.873 |
| 16 | 16 | 4096 | 4096 | 64 | head_bias | torch.bfloat16 | 40308.606 | 6829.587 | 5.902 |
| 16 | 16 | 4096 | 4096 | 64 | pathological | torch.bfloat16 | 108892.137 | 233090.953 | 0.467 |

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121845
Approved by: https://github.com/Chillee, https://github.com/zou3519
Committed by: PyTorch MergeBot
Parent: 8e98fda7a9
Commit: f4e2a226aa
benchmarks/transformer/score_mod.py (new file, 259 lines):
```Python
import itertools
from collections import defaultdict
from dataclasses import asdict, dataclass
from functools import partial
from typing import Callable, List

import numpy as np
import torch
import torch.utils.benchmark as benchmark
from tabulate import tabulate
from torch.nn.attention._templated_attention import _compose, _templated_attention
from tqdm import tqdm

torch._dynamo.config.automatic_dynamic_shapes = False
# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    # warmup
    for _ in range(5):
        func(*args, **kwargs)
    t0 = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "func": func},
    )
    return t0.adaptive_autorange(min_run_time=0.1).median * 1e6


@dataclass(frozen=True)
class ExperimentConfig:
    batch_size: int
    num_heads: int
    q_seq_len: int
    k_seq_len: int
    head_dim: int
    score_mod: Callable
    dtype: torch.dtype

    def asdict(self):
        return asdict(self)


@dataclass(frozen=True)
class ExperimentResults:
    eager_time: float
    compiled_time: float

    def get_entries(self) -> List:
        return [
            f"{self.eager_time:2f}",
            f"{self.compiled_time:2f}",
        ]


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    results: ExperimentResults

    def get_entries(self) -> List:
        return self.config.get_entries() + self.results.get_entries()

    def asdict(self):
        dict1 = asdict(self.config)
        dict2 = asdict(self.results)
        return {**dict1, **dict2}


def generate_inputs(
    batch_size,
    num_heads,
    q_sequence_length,
    kv_sequence_length,
    head_dim,
    dtype,
    device,
):
    q_shape = (batch_size, q_sequence_length, num_heads * head_dim)
    kv_shape = (batch_size, kv_sequence_length, num_heads * head_dim)

    make_q = partial(torch.rand, q_shape, device=device, dtype=dtype)
    make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype)
    query = (
        make_q()
        .view(batch_size, q_sequence_length, num_heads, head_dim)
        .transpose(1, 2)
    )
    key = (
        make_kv()
        .view(batch_size, kv_sequence_length, num_heads, head_dim)
        .transpose(1, 2)
    )
    value = (
        make_kv()
        .view(batch_size, kv_sequence_length, num_heads, head_dim)
        .transpose(1, 2)
    )
    return query, key, value


def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
    device = torch.device("cuda")
    query, key, value = generate_inputs(
        config.batch_size,
        config.num_heads,
        config.q_seq_len,
        config.k_seq_len,
        config.head_dim,
        config.dtype,
        device,
    )
    eager_sdpa = _templated_attention
    compiled_sdpa = torch.compile(eager_sdpa)

    score_mod = config.score_mod

    forward_eager_time = benchmark_torch_function_in_microseconds(
        eager_sdpa, query, key, value, score_mod
    )
    forward_compiled_time = benchmark_torch_function_in_microseconds(
        compiled_sdpa, query, key, value, score_mod
    )

    return ExperimentResults(
        eager_time=forward_eager_time,
        compiled_time=forward_compiled_time,
    )


def calculate_speedup(results: ExperimentResults) -> float:
    return results.eager_time / results.compiled_time


def get_func_name(func):
    return func.__name__.split("<locals>.")[-1].split(" at ")[0]


def get_average_speedups(results: List[Experiment]):
    # Calculate speedups
    speedups = [calculate_speedup(r.results) for r in results]

    # Find indices of max and min speedups
    max_speedup_index = np.argmax(speedups)
    min_speedup_index = np.argmin(speedups)

    # Get the config dictionaries
    max_config_dict = results[max_speedup_index].config.asdict()
    min_config_dict = results[min_speedup_index].config.asdict()

    # Extract function names from score_mod strings
    max_config_dict["score_mod"] = (
        max_config_dict["score_mod"].__name__.split("<locals>.")[-1].split(" at ")[0]
    )
    min_config_dict["score_mod"] = (
        min_config_dict["score_mod"].__name__.split("<locals>.")[-1].split(" at ")[0]
    )

    # Create table data
    table_data = [
        {
            "Type": "Average",
            "Speedup": np.mean(speedups),
            **dict.fromkeys(max_config_dict),
        },
        {"Type": "Max", "Speedup": speedups[max_speedup_index], **max_config_dict},
        {"Type": "Min", "Speedup": speedups[min_speedup_index], **min_config_dict},
    ]

    return table_data


def print_results(results: List[Experiment]):
    table_data = defaultdict(list)
    for experiment in results:
        for key, value in experiment.asdict().items():
            if key == "eager_time" or key == "compiled_time":
                value = float(value)
            table_data[key].append(value)

    # Calculate speedups
    speedups = [calculate_speedup(r.results) for r in results]
    table_data["speedup"] = speedups

    table_data["score_mod"] = [get_func_name(func) for func in table_data["score_mod"]]
    print(tabulate(table_data, headers="keys", tablefmt="github", floatfmt=".3f"))

    average_data = get_average_speedups(results)
    print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))


def generate_score_mods() -> List[Callable]:
    def causal_mask(score, b, h, token_q, token_kv):
        return torch.where(token_q >= token_kv, score, float("-inf"))

    def relative_bias(score, b, h, m, n):
        return score + (m - n)

    def head_bias(score, b, h, m, n):
        return score + 2 * h

    def pathological(score, b, h, m, n):
        def sin(score, b, h, m, n):
            return torch.sin(score)

        composed_mod = _compose(*(sin for _ in range(10)))
        return composed_mod(score, b, h, m, n)

    return [causal_mask, relative_bias, head_bias, pathological]


def generate_experiment_configs() -> List[ExperimentConfig]:
    batch_sizes = [1, 8, 16]
    num_heads = [16]
    q_kv_seq_lens = [(512, 512), (1024, 1024), (4096, 4096)]
    head_dims = [64]
    dtypes = [
        torch.bfloat16,
    ]
    score_mods = generate_score_mods()
    all_configs = []
    for (
        bsz,
        n_heads,
        (q_seq_len, kv_seq_len),
        head_dim,
        score_mod,
        dtype,
    ) in itertools.product(
        batch_sizes, num_heads, q_kv_seq_lens, head_dims, score_mods, dtypes
    ):
        all_configs.append(
            ExperimentConfig(
                batch_size=bsz,
                num_heads=n_heads,
                q_seq_len=q_seq_len,
                k_seq_len=kv_seq_len,
                head_dim=head_dim,
                score_mod=score_mod,
                dtype=dtype,
            )
        )

    return all_configs


def main():
    seed = 123
    np.random.seed(seed)
    torch.manual_seed(seed)
    results = []
    for config in tqdm(generate_experiment_configs()):
        results.append(Experiment(config, run_single_experiment(config)))

    print_results(results)


if __name__ == "__main__":
    main()
```
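As a quick, hypothetical smoke test (not part of the committed file), the harness above can be exercised on a single configuration; this assumes a CUDA device and reuses the definitions from the script.

```Python
# Hypothetical usage of the benchmark harness above; requires CUDA.
config = ExperimentConfig(
    batch_size=8,
    num_heads=16,
    q_seq_len=1024,
    k_seq_len=1024,
    head_dim=64,
    score_mod=generate_score_mods()[0],  # causal_mask
    dtype=torch.bfloat16,
)
print(run_single_experiment(config))  # ExperimentResults(eager_time=..., compiled_time=...)
```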