pytorch/benchmarks/transformer
Yanbo Liang 3b5f6b10ad [Inductor] default block size for head_dim = 256 for flex attention (#125380)
## H100
### torch.bfloat16
No major change, as expected.
```
| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod   | dtype          |
|---------|-----------|--------------|-------------|-------------|-------------|------------|-------------|----------------|
| Average |     1.122 |              |             |             |             |            |             |                |
| Max     |     1.437 |            1 |          16 |         512 |         512 |        128 | head_bias   | torch.bfloat16 |
| Min     |     0.895 |            1 |          16 |        1024 |        1024 |         64 | head_bias   | torch.bfloat16 |
```
### torch.float32
Before: OOM when `head_dim` = 256
After:
```
| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod   | dtype         |
|---------|-----------|--------------|-------------|-------------|-------------|------------|-------------|---------------|
| Average |     2.231 |              |             |             |             |            |             |               |
| Max     |     3.760 |           16 |          16 |        4096 |        4096 |         64 | noop        | torch.float32 |
| Min     |     1.532 |            1 |          16 |         512 |         512 |        256 | causal_mask | torch.float32 |
```

## A100
### torch.bfloat16
Before:
```
| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod     | dtype          |
|---------|-----------|--------------|-------------|-------------|-------------|------------|---------------|----------------|
| Average |     0.587 |              |             |             |             |            |               |                |
| Max     |     0.960 |            1 |          16 |         512 |         512 |         64 | noop          | torch.bfloat16 |
| Min     |     0.017 |            8 |          16 |        4096 |        4096 |        256 | relative_bias | torch.bfloat16 |
```
After:
```
| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod   | dtype          |
|---------|-----------|--------------|-------------|-------------|-------------|------------|-------------|----------------|
| Average |     0.756 |              |             |             |             |            |             |                |
| Max     |     0.931 |            1 |          16 |         512 |         512 |         64 | noop        | torch.bfloat16 |
| Min     |     0.467 |           16 |          16 |        1024 |        1024 |        256 | noop        | torch.bfloat16 |
```

### torch.float32
Before: OOM when `head_dim` = 256
After:
```
| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod   | dtype         |
|---------|-----------|--------------|-------------|-------------|-------------|------------|-------------|---------------|
| Average |     2.386 |              |             |             |             |            |             |               |
| Max     |     7.584 |           16 |          16 |         512 |         512 |         64 | noop        | torch.float32 |
| Min     |     0.948 |            1 |          16 |         512 |         512 |        256 | causal_mask | torch.float32 |
```
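
For readers unfamiliar with the table layout, here is a minimal sketch of how an Average/Max/Min summary like the ones above could be derived from per-configuration timings. It assumes speedup is the baseline time divided by the flex attention time and that the average is an arithmetic mean; the actual benchmark script may define either differently. The names `Result` and `summarize` and all timing values are hypothetical and not taken from the benchmark.

```python
from dataclasses import dataclass
from statistics import mean


@dataclass(frozen=True)
class Result:
    """One benchmark configuration with its two measured times (hypothetical)."""
    head_dim: int
    score_mod: str
    baseline_us: float   # assumed: time of the reference implementation
    candidate_us: float  # assumed: time of the flex attention kernel

    @property
    def speedup(self) -> float:
        # Assumed definition: values above 1.0 mean the candidate is faster.
        return self.baseline_us / self.candidate_us


def summarize(results):
    """Return the mean speedup plus the best and worst configurations."""
    best = max(results, key=lambda r: r.speedup)
    worst = min(results, key=lambda r: r.speedup)
    return {
        "Average": mean(r.speedup for r in results),
        "Max": best,
        "Min": worst,
    }


# Made-up timings, for illustration only.
results = [
    Result(head_dim=64, score_mod="noop", baseline_us=200.0, candidate_us=100.0),
    Result(head_dim=128, score_mod="head_bias", baseline_us=300.0, candidate_us=200.0),
    Result(head_dim=256, score_mod="causal_mask", baseline_us=90.0, candidate_us=100.0),
]
summary = summarize(results)
```

Under this reading, the Max and Min rows in each table report the single configuration with the highest and lowest speedup, which is why only those rows carry values in the configuration columns.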

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125380
Approved by: https://github.com/drisspg
2024-05-02 22:51:07 +00:00