ScoreMod API (#121845)

# Summary

This PR adds a new higher-order_op: `templated_attention`.  This op is designed to extend the functionality of torch.nn.fucntional.scaled_dot_product_attention.  PyTorch has efficient pre-written fused-attention kernels. However, users want to modify how scores are computed (a substep inside attention) -- this traditionally requires the user to write their own attention kernel. One such modification to attention scores that is not currently supported by the top level SDPA op is:[ Attention with Linear Biases (ALiBi](https://arxiv.org/abs/2108.12409)).

This higher-order op will instead accept a callable( 'score_mod') function that is through torch.compile will be used to create an efficient attention kernel instantiation.

### Details

This HOP utilizes the existing fx and HOP infra to capture and convert the User `score-mod` function and convert to an FX graph module. Inductor then consumes this HOP that has a `ir.Subgraph` input. It will inline this lowered subgraph into a triton kernel which performs fused attention with the modification to the scores matrix inlined.

### API

The API for a score_mod function should be as follows:

```Python
def score_mod(score: torch.Tensor, batch: torch.Tensor, head: torch.Tensor, token_1: torch.Tensor, token_kv: torch.Tensor) -> torch.Tensor
```

This function receives five parameters:

- `score`: A scalar tensor representing the attention score, with the same data type and device as the query, key, and value tensors.
- `batch`, `head`, `seq_len_q`, `seq_len_kv`: Scalar tensors indicating the batch index, head index, query index, and key/value index, respectively, with torch.int data type and located on the same device as the score tensor.

Consider inputs query, key, value of shapes (2, 4, 16, 8), leading to an intermediate attention score matrix of shape (2, 4, 16, 16)

The score_mod function will be vectorized over each element of this matrix. For instance, modifying the score at the position corresponding to the 0th batch, 2nd head, between the 8th query and the 9th key element, would be invoked as:

```Python
score_mod(score[0,2,8,9], torch.tensor(0), torch.tensor(2), torch.tensor(8), torch.tensor(9))
```

### Examples
```Python
import torch
from torch.nn.attention.templated_attention import templated_attention

torch.manual_seed(0)

# Lets create some input tensors
# The input tensor has shape (batch_size, num_heads, seq_len, head_dim)
query = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)
key = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)
value = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)

# Lets create a fun new score_modification! I will call this
# Checkerboard. It will reduce the score for neighboring tokens (1 step apart)
# in the sequence. And increase the score for tokens 2 steps apart. For everything
# else, the score will remain the same.

def checkerboard(score, batch, head, token_q, token_kv):
    score = torch.where(torch.abs(token_kv - token_q) == 1, score * 0.5, score)
    score = torch.where(torch.abs(token_kv - token_q) == 2, score * 2.0, score)
    return score

# Lets call templated_attention with this new score modification
output = templated_attention(query, key, value, score_mod=checkerboard)

compiled_templated_attention = torch.compile(templated_attention)
out_compiled = compiled_templated_attention(query, key, value, score_mod=checkerboard)

torch.testing.assert_close(output, out_compiled, atol=2e-2, rtol=2e-2)
```

### Future Work
- This PR is currently only forward only. However the triton kernel for backwards where score_modifications to not rely on external buffers has been explored here: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/flash/flash_attention.py
- Kernel Improvements; There are has been some larger updates to the fused attention implementation that Triton uses in its tutorials. The implementation of this kernel is based on a prior version and should be updated.
- We may want to unify this API under the top level SDPA API and leave that as a follow up once this is more stable
- Should we error on CPU?
- There are some issues with dynamic shapes
- Capturing of free variables and lifting to inputs to the subgraph is not working correctly today

### Performance
Comparisons generated by this benchmark:

| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod     | dtype          |
|---------|-----------|--------------|-------------|-------------|-------------|------------|---------------|----------------|
| Average |     5.412 |              |             |             |             |            |               |                |
| Max     |     8.882 |           16 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |
| Min     |     3.645 |            8 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |
| Min     |     0.345 |            1 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |

For reference

| Configuration                                 | Forward Time (µ seconds) | Backend          | Speedup |
|-----------------------------------------------|--------------------------|------------------|---------|
| Fastest Config in Sweep (`8 16 4096 4096 64 relative_bias torch.bfloat16`) | 3608                   | Templated Attention                | 1.0  |
| Compiled SDPA (No Mask)                       | 9928                   | Math             | 2.75x   |
| Compiled SDPA (With Mask)                     | 11898                    | Math             | 3.29x   |
| Compiled SDPA (With Mask) | 8704                      | Memory Efficient Attention | 2.42x   |
| Compiled SDPA (No Mask) | 2548                     | FlashAttention2 | 0.706x   |

The speedups are measuring compiled templated attention speed versus different calls to torch.nn.functional.sdpa

<details>

<summary> FULL PERFORMANCE SWEEP NUMBERS </summary>

|   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod     | dtype          |   eager_time |   compiled_time |   speedup |
|--------------|-------------|-------------|-------------|------------|---------------|----------------|--------------|-----------------|-----------|
|            1 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |      331.444 |          67.221 |     4.931 |
|            1 |          16 |         512 |         512 |         64 | relative_bias | torch.bfloat16 |      335.300 |          64.187 |     5.224 |
|            1 |          16 |         512 |         512 |         64 | head_bias     | torch.bfloat16 |      352.039 |          63.806 |     5.517 |
|            1 |          16 |         512 |         512 |         64 | pathological  | torch.bfloat16 |      371.699 |         711.349 |     0.523 |
|            1 |          16 |        1024 |        1024 |         64 | causal_mask   | torch.bfloat16 |      333.488 |          86.455 |     3.857 |
|            1 |          16 |        1024 |        1024 |         64 | relative_bias | torch.bfloat16 |      322.363 |          82.469 |     3.909 |
|            1 |          16 |        1024 |        1024 |         64 | head_bias     | torch.bfloat16 |      349.967 |          82.233 |     4.256 |
|            1 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |      486.359 |        1412.453 |     0.344 |
|            1 |          16 |        4096 |        4096 |         64 | causal_mask   | torch.bfloat16 |     2794.597 |         551.188 |     5.070 |
|            1 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |     3965.150 |         513.101 |     7.728 |
|            1 |          16 |        4096 |        4096 |         64 | head_bias     | torch.bfloat16 |     2408.013 |         504.759 |     4.771 |
|            1 |          16 |        4096 |        4096 |         64 | pathological  | torch.bfloat16 |     6850.531 |       16733.675 |     0.409 |
|            8 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |      441.939 |         123.576 |     3.576 |
|            8 |          16 |         512 |         512 |         64 | relative_bias | torch.bfloat16 |      560.379 |         116.710 |     4.801 |
|            8 |          16 |         512 |         512 |         64 | head_bias     | torch.bfloat16 |      421.172 |         115.825 |     3.636 |
|            8 |          16 |         512 |         512 |         64 | pathological  | torch.bfloat16 |      994.492 |        2132.806 |     0.466 |
|            8 |          16 |        1024 |        1024 |         64 | causal_mask   | torch.bfloat16 |     1436.430 |         309.495 |     4.641 |
|            8 |          16 |        1024 |        1024 |         64 | relative_bias | torch.bfloat16 |     1892.216 |         290.186 |     6.521 |
|            8 |          16 |        1024 |        1024 |         64 | head_bias     | torch.bfloat16 |     1360.665 |         282.956 |     4.809 |
|            8 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |     3525.532 |        8359.702 |     0.422 |
|            8 |          16 |        4096 |        4096 |         64 | causal_mask   | torch.bfloat16 |    22026.839 |        3864.604 |     5.700 |
|            8 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |    31262.746 |        3609.551 |     8.661 |
|            8 |          16 |        4096 |        4096 |         64 | head_bias     | torch.bfloat16 |    20219.079 |        3480.402 |     5.809 |
|            8 |          16 |        4096 |        4096 |         64 | pathological  | torch.bfloat16 |    54654.647 |      116652.357 |     0.469 |
|           16 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |      820.606 |         188.683 |     4.349 |
|           16 |          16 |         512 |         512 |         64 | relative_bias | torch.bfloat16 |     1058.362 |         179.295 |     5.903 |
|           16 |          16 |         512 |         512 |         64 | head_bias     | torch.bfloat16 |      784.372 |         175.714 |     4.464 |
|           16 |          16 |         512 |         512 |         64 | pathological  | torch.bfloat16 |     1890.792 |        4212.877 |     0.449 |
|           16 |          16 |        1024 |        1024 |         64 | causal_mask   | torch.bfloat16 |     2781.830 |         557.017 |     4.994 |
|           16 |          16 |        1024 |        1024 |         64 | relative_bias | torch.bfloat16 |     3694.050 |         525.249 |     7.033 |
|           16 |          16 |        1024 |        1024 |         64 | head_bias     | torch.bfloat16 |     2634.164 |         507.613 |     5.189 |
|           16 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |     6959.917 |       15331.116 |     0.454 |
|           16 |          16 |        4096 |        4096 |         64 | causal_mask   | torch.bfloat16 |    43889.096 |        7582.018 |     5.789 |
|           16 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |    62784.293 |        7075.846 |     8.873 |
|           16 |          16 |        4096 |        4096 |         64 | head_bias     | torch.bfloat16 |    40308.606 |        6829.587 |     5.902 |
|           16 |          16 |        4096 |        4096 |         64 | pathological  | torch.bfloat16 |   108892.137 |      233090.953 |     0.467 |
</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121845
Approved by: https://github.com/Chillee, https://github.com/zou3519
This commit is contained in:
drisspg
2024-04-06 01:10:40 +00:00
committed by PyTorch MergeBot
parent 8e98fda7a9
commit f4e2a226aa
13 changed files with 1217 additions and 12 deletions

View File

@ -0,0 +1,259 @@
import itertools
from collections import defaultdict
from dataclasses import asdict, dataclass
from functools import partial
from typing import Callable, List
import numpy as np
import torch
import torch.utils.benchmark as benchmark
from tabulate import tabulate
from torch.nn.attention._templated_attention import _compose, _templated_attention
from tqdm import tqdm
torch._dynamo.config.automatic_dynamic_shapes = False
# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000
def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
# warmup
for _ in range(5):
func(*args, **kwargs)
t0 = benchmark.Timer(
stmt="func(*args, **kwargs)",
globals={"args": args, "kwargs": kwargs, "func": func},
)
return t0.adaptive_autorange(min_run_time=0.1).median * 1e6
@dataclass(frozen=True)
class ExperimentConfig:
batch_size: int
num_heads: int
q_seq_len: int
k_seq_len: int
head_dim: int
score_mod: Callable
dtype: torch.dtype
def asdict(self):
return asdict(self)
@dataclass(frozen=True)
class ExperimentResults:
eager_time: float
compiled_time: float
def get_entries(self) -> List:
return [
f"{self.eager_time:2f}",
f"{self.compiled_time:2f}",
]
@dataclass(frozen=True)
class Experiment:
config: ExperimentConfig
results: ExperimentResults
def get_entries(self) -> List:
return self.config.get_entries() + self.results.get_entries()
def asdict(self):
dict1 = asdict(self.config)
dict2 = asdict(self.results)
return {**dict1, **dict2}
def generate_inputs(
batch_size,
num_heads,
q_sequence_length,
kv_sequence_length,
head_dim,
dtype,
device,
):
q_shape = (batch_size, q_sequence_length, num_heads * head_dim)
kv_shape = (batch_size, kv_sequence_length, num_heads * head_dim)
make_q = partial(torch.rand, q_shape, device=device, dtype=dtype)
make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype)
query = (
make_q()
.view(batch_size, q_sequence_length, num_heads, head_dim)
.transpose(1, 2)
)
key = (
make_kv()
.view(batch_size, kv_sequence_length, num_heads, head_dim)
.transpose(1, 2)
)
value = (
make_kv()
.view(batch_size, kv_sequence_length, num_heads, head_dim)
.transpose(1, 2)
)
return query, key, value
def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
device = torch.device("cuda")
query, key, value = generate_inputs(
config.batch_size,
config.num_heads,
config.q_seq_len,
config.k_seq_len,
config.head_dim,
config.dtype,
device,
)
eager_sdpa = _templated_attention
compiled_sdpa = torch.compile(eager_sdpa)
score_mod = config.score_mod
forward_eager_time = benchmark_torch_function_in_microseconds(
eager_sdpa, query, key, value, score_mod
)
forward_compiled_time = benchmark_torch_function_in_microseconds(
compiled_sdpa, query, key, value, score_mod
)
return ExperimentResults(
eager_time=forward_eager_time,
compiled_time=forward_compiled_time,
)
def calculate_speedup(results: ExperimentResults) -> float:
return results.eager_time / results.compiled_time
def get_func_name(func):
return func.__name__.split("<locals>.")[-1].split(" at ")[0]
def get_average_speedups(results: List[Experiment]):
# Calculate speedups
speedups = [calculate_speedup(r.results) for r in results]
# Find indices of max and min speedups
max_speedup_index = np.argmax(speedups)
min_speedup_index = np.argmin(speedups)
# Get the config dictionaries
max_config_dict = results[max_speedup_index].config.asdict()
min_config_dict = results[min_speedup_index].config.asdict()
# Extract function names from score_mod strings
max_config_dict["score_mod"] = (
max_config_dict["score_mod"].__name__.split("<locals>.")[-1].split(" at ")[0]
)
min_config_dict["score_mod"] = (
min_config_dict["score_mod"].__name__.split("<locals>.")[-1].split(" at ")[0]
)
# Create table data
table_data = [
{
"Type": "Average",
"Speedup": np.mean(speedups),
**dict.fromkeys(max_config_dict),
},
{"Type": "Max", "Speedup": speedups[max_speedup_index], **max_config_dict},
{"Type": "Min", "Speedup": speedups[min_speedup_index], **min_config_dict},
]
return table_data
def print_results(results: List[Experiment]):
table_data = defaultdict(list)
for experiment in results:
for key, value in experiment.asdict().items():
if key == "eager_time" or key == "compiled_time":
value = float(value)
table_data[key].append(value)
# Calculate speedups
speedups = [calculate_speedup(r.results) for r in results]
table_data["speedup"] = speedups
table_data["score_mod"] = [get_func_name(func) for func in table_data["score_mod"]]
print(tabulate(table_data, headers="keys", tablefmt="github", floatfmt=".3f"))
average_data = get_average_speedups(results)
print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))
def generate_score_mods() -> List[Callable]:
def causal_mask(score, b, h, token_q, token_kv):
return torch.where(token_q >= token_kv, score, float("-inf"))
def relative_bias(score, b, h, m, n):
return score + (m - n)
def head_bias(score, b, h, m, n):
return score + 2 * h
def pathological(score, b, h, m, n):
def sin(score, b, h, m, n):
return torch.sin(score)
composed_mod = _compose(*(sin for _ in range(10)))
return composed_mod(score, b, h, m, n)
return [causal_mask, relative_bias, head_bias, pathological]
def generate_experiment_configs() -> List[ExperimentConfig]:
batch_sizes = [1, 8, 16]
num_heads = [16]
q_kv_seq_lens = [(512, 512), (1024, 1024), (4096, 4096)]
head_dims = [64]
dtypes = [
torch.bfloat16,
]
score_mods = generate_score_mods()
all_configs = []
for (
bsz,
n_heads,
(q_seq_len, kv_seq_len),
head_dim,
score_mod,
dtype,
) in itertools.product(
batch_sizes, num_heads, q_kv_seq_lens, head_dims, score_mods, dtypes
):
all_configs.append(
ExperimentConfig(
batch_size=bsz,
num_heads=n_heads,
q_seq_len=q_seq_len,
k_seq_len=kv_seq_len,
head_dim=head_dim,
score_mod=score_mod,
dtype=dtype,
)
)
return all_configs
def main():
seed = 123
np.random.seed(seed)
torch.manual_seed(seed)
results = []
for config in tqdm(generate_experiment_configs()):
results.append(Experiment(config, run_single_experiment(config)))
print_results(results)
if __name__ == "__main__":
main()