[Inductor] default block size for head_dim = 256 for flex attention (#125380)
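
For context, a minimal sketch of the `head_dim = 256` workload these defaults target. The import path and `score_mod` follow the current public flex attention API, which was still private at the time of this PR, and the shapes mirror the smallest `head_dim = 256` row in the tables below; this is not the actual benchmark harness (that change is in the diff at the end).
```python
import torch
from torch.nn.attention.flex_attention import flex_attention  # path in recent releases; assumed here

# Smallest head_dim = 256 configuration from the benchmark tables below.
B, H, S, D = 1, 16, 512, 256
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float32) for _ in range(3))

def causal_mask(score, b, h, q_idx, kv_idx):
    # "causal_mask" score_mod variant as named in the tables.
    return torch.where(q_idx >= kv_idx, score, float("-inf"))

# Compiling routes through Inductor, whose default block size for
# head_dim = 256 is what this PR adjusts (previously OOM for float32).
compiled_flex = torch.compile(flex_attention)
out = compiled_flex(q, k, v, score_mod=causal_mask)
```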

## H100
### torch.bfloat16
No major change, as expected.
```
| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod   | dtype          |
|---------|-----------|--------------|-------------|-------------|-------------|------------|-------------|----------------|
| Average |     1.122 |              |             |             |             |            |             |                |
| Max     |     1.437 |            1 |          16 |         512 |         512 |        128 | head_bias   | torch.bfloat16 |
| Min     |     0.895 |            1 |          16 |        1024 |        1024 |         64 | head_bias   | torch.bfloat16 |
```
### torch.float32
Before: OOM when `head_dim` = 256
After:
```
| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod   | dtype         |
|---------|-----------|--------------|-------------|-------------|-------------|------------|-------------|---------------|
| Average |     2.231 |              |             |             |             |            |             |               |
| Max     |     3.760 |           16 |          16 |        4096 |        4096 |         64 | noop        | torch.float32 |
| Min     |     1.532 |            1 |          16 |         512 |         512 |        256 | causal_mask | torch.float32 |
```

## A100
### torch.bfloat16
Before:
```
| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod     | dtype          |
|---------|-----------|--------------|-------------|-------------|-------------|------------|---------------|----------------|
| Average |     0.587 |              |             |             |             |            |               |                |
| Max     |     0.960 |            1 |          16 |         512 |         512 |         64 | noop          | torch.bfloat16 |
| Min     |     0.017 |            8 |          16 |        4096 |        4096 |        256 | relative_bias | torch.bfloat16 |
```
After:
```
| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod   | dtype          |
|---------|-----------|--------------|-------------|-------------|-------------|------------|-------------|----------------|
| Average |     0.756 |              |             |             |             |            |             |                |
| Max     |     0.931 |            1 |          16 |         512 |         512 |         64 | noop        | torch.bfloat16 |
| Min     |     0.467 |           16 |          16 |        1024 |        1024 |        256 | noop        | torch.bfloat16 |
```

### torch.float32
Before: OOM when `head_dim` = 256
After:
```
| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod   | dtype         |
|---------|-----------|--------------|-------------|-------------|-------------|------------|-------------|---------------|
| Average |     2.386 |              |             |             |             |            |             |               |
| Max     |     7.584 |           16 |          16 |         512 |         512 |         64 | noop        | torch.float32 |
| Min     |     0.948 |            1 |          16 |         512 |         512 |        256 | causal_mask | torch.float32 |
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125380
Approved by: https://github.com/drisspg

The benchmark sweep now includes `head_dim = 256`:
```diff
@@ -211,7 +211,7 @@ def generate_experiment_configs() -> List[ExperimentConfig]:
     batch_sizes = [1, 8, 16]
     num_heads = [16]
     q_kv_seq_lens = [(512, 512), (1024, 1024), (4096, 4096)]
-    head_dims = [64, 128]
+    head_dims = [64, 128, 256]
     dtypes = [
         torch.bfloat16,
     ]
```
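
For illustration, the surrounding generator presumably cross-products these lists into per-shape configs that become the table columns above; a hedged sketch, with the dataclass fields and structure assumed rather than taken from the actual benchmark script:
```python
import itertools
from dataclasses import dataclass
from typing import List

import torch

@dataclass(frozen=True)
class ExperimentConfig:
    # Field names are illustrative; they mirror the benchmark table columns.
    batch_size: int
    num_heads: int
    q_seq_len: int
    k_seq_len: int
    head_dim: int
    dtype: torch.dtype

def generate_experiment_configs() -> List[ExperimentConfig]:
    batch_sizes = [1, 8, 16]
    num_heads = [16]
    q_kv_seq_lens = [(512, 512), (1024, 1024), (4096, 4096)]
    head_dims = [64, 128, 256]  # 256 is the case newly covered by this PR
    dtypes = [torch.bfloat16]
    # Cross-product every axis into one config per benchmarked shape.
    return [
        ExperimentConfig(b, h, q_len, kv_len, d, dt)
        for b, h, (q_len, kv_len), d, dt in itertools.product(
            batch_sizes, num_heads, q_kv_seq_lens, head_dims, dtypes
        )
    ]
```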