Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[Inductor] default block size for head_dim = 256 for flex attention (#125380)
## H100

### torch.bfloat16

No major change, as expected.

| Type | Speedup | batch_size | num_heads | q_seq_len | k_seq_len | head_dim | score_mod | dtype |
|---------|---------|------------|-----------|-----------|-----------|----------|-----------|----------------|
| Average | 1.122 | | | | | | | |
| Max | 1.437 | 1 | 16 | 512 | 512 | 128 | head_bias | torch.bfloat16 |
| Min | 0.895 | 1 | 16 | 1024 | 1024 | 64 | head_bias | torch.bfloat16 |

### torch.float32

Before: OOM when `head_dim` = 256

After:

| Type | Speedup | batch_size | num_heads | q_seq_len | k_seq_len | head_dim | score_mod | dtype |
|---------|---------|------------|-----------|-----------|-----------|----------|-------------|---------------|
| Average | 2.231 | | | | | | | |
| Max | 3.760 | 16 | 16 | 4096 | 4096 | 64 | noop | torch.float32 |
| Min | 1.532 | 1 | 16 | 512 | 512 | 256 | causal_mask | torch.float32 |

## A100

### torch.bfloat16

Before:

| Type | Speedup | batch_size | num_heads | q_seq_len | k_seq_len | head_dim | score_mod | dtype |
|---------|---------|------------|-----------|-----------|-----------|----------|---------------|----------------|
| Average | 0.587 | | | | | | | |
| Max | 0.960 | 1 | 16 | 512 | 512 | 64 | noop | torch.bfloat16 |
| Min | 0.017 | 8 | 16 | 4096 | 4096 | 256 | relative_bias | torch.bfloat16 |

After:

| Type | Speedup | batch_size | num_heads | q_seq_len | k_seq_len | head_dim | score_mod | dtype |
|---------|---------|------------|-----------|-----------|-----------|----------|-----------|----------------|
| Average | 0.756 | | | | | | | |
| Max | 0.931 | 1 | 16 | 512 | 512 | 64 | noop | torch.bfloat16 |
| Min | 0.467 | 16 | 16 | 1024 | 1024 | 256 | noop | torch.bfloat16 |

### torch.float32

Before: OOM when `head_dim` = 256

After:

| Type | Speedup | batch_size | num_heads | q_seq_len | k_seq_len | head_dim | score_mod | dtype |
|---------|---------|------------|-----------|-----------|-----------|----------|-------------|---------------|
| Average | 2.386 | | | | | | | |
| Max | 7.584 | 16 | 16 | 512 | 512 | 64 | noop | torch.float32 |
| Min | 0.948 | 1 | 16 | 512 | 512 | 256 | causal_mask | torch.float32 |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125380
Approved by: https://github.com/drisspg
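The Average/Max/Min rows in the benchmark tables above are summary statistics over per-config speedups. A minimal sketch of how such a summary can be computed, using hypothetical speedup numbers keyed by (`score_mod`, `head_dim`) for illustration:

```python
# Hypothetical speedup measurements keyed by (score_mod, head_dim);
# the real benchmark also sweeps batch_size, num_heads, seq_len, and dtype.
speedups = {
    ("noop", 64): 3.760,
    ("head_bias", 128): 2.100,
    ("causal_mask", 256): 1.532,
}

# Summary statistics, mirroring the Average/Max/Min rows of the tables.
avg = sum(speedups.values()) / len(speedups)
best = max(speedups, key=speedups.get)   # config with the highest speedup
worst = min(speedups, key=speedups.get)  # config with the lowest speedup
print(round(avg, 3), best, worst)
```

The Max/Min rows report the full config alongside the extreme value, which is why only those rows have batch_size, seq_len, etc. filled in.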
Committed by: PyTorch MergeBot
Parent: 5c7b71dccf
Commit: 3b5f6b10ad
```diff
@@ -211,7 +211,7 @@ def generate_experiment_configs() -> List[ExperimentConfig]:
     batch_sizes = [1, 8, 16]
     num_heads = [16]
     q_kv_seq_lens = [(512, 512), (1024, 1024), (4096, 4096)]
-    head_dims = [64, 128]
+    head_dims = [64, 128, 256]
     dtypes = [
         torch.bfloat16,
     ]
```
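The grids in the diff above are typically expanded into one config per combination. A hedged sketch of how `generate_experiment_configs` might enumerate them (the names come from the diff; the `ExperimentConfig` fields and the Cartesian-product expansion are assumptions, and a string stands in for the torch dtype):

```python
from itertools import product
from typing import List, NamedTuple, Tuple

class ExperimentConfig(NamedTuple):
    # Field names are illustrative; the real ExperimentConfig is not shown here.
    batch_size: int
    num_heads: int
    q_kv_seq_len: Tuple[int, int]
    head_dim: int
    dtype: str  # stands in for torch.bfloat16 in this sketch

def generate_experiment_configs() -> List[ExperimentConfig]:
    # Parameter grids from the diff; head_dims now includes 256.
    batch_sizes = [1, 8, 16]
    num_heads = [16]
    q_kv_seq_lens = [(512, 512), (1024, 1024), (4096, 4096)]
    head_dims = [64, 128, 256]
    dtypes = ["torch.bfloat16"]
    # Cross product of all grids: one config per combination.
    return [
        ExperimentConfig(b, h, s, d, dt)
        for b, h, s, d, dt in product(
            batch_sizes, num_heads, q_kv_seq_lens, head_dims, dtypes
        )
    ]

configs = generate_experiment_configs()
print(len(configs))  # 3 * 1 * 3 * 3 * 1 = 27
```

Adding 256 to `head_dims` grows the sweep from 18 to 27 configs, which is why the benchmark tables above now include `head_dim` = 256 rows.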