AutoHeuristic: mixed_mm H100 heuristic (#132685)

H100 heuristic for mixed_mm. Performance looks similar to A100 heuristic.
```
  set     crit  max_depth  min_samples_leaf  correct  wrong  unsure  total  wrong_max_spdup  wrong_gman_spdup  max_spdup_default  gman_spdup_default  max_slowdown_default  non_default_preds  default_better
train  entropy          5              0.01     1562    604     145   2311         1.522201          1.077722          10.399141            3.134170              1.034802               2061               2
 test  entropy          5              0.01      361    164      24    549         1.443590          1.079169           8.159173            3.105360              1.197973                500               2
```

gpt-fast speedups
|batch size|prompt length| fallback    |  heuristic  | speedup |
|----------|-------------|------------:|------------:|--------:|
|     1    |      7      |      109.95  |       220.63|  2      |
|     1    |     11      |      109.65  | 	    210.92|  1.92   |
|     4    |      7      |       149.04 |       625.80|  4.19   |
|     4    |     11      |       149.56 |       494.64|  3.30   |
|     8    |      7      |       293.68 |       956.72|  3.25   |
|     8    |     11      |       294.48 |       925.60|  3.14   |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132685
Approved by: https://github.com/eellison
This commit is contained in:
Alnis Murtovi
2024-08-07 11:10:43 -07:00
committed by PyTorch MergeBot
parent c327710a87
commit 383f2ac914
3 changed files with 163 additions and 4 deletions

View File

@ -0,0 +1,148 @@
# flake8: noqa: B950
# fmt: off
# This file was generated by AutoHeuristic. Do not modify it manually!
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/
from typing import List, Optional, Tuple
from torch._inductor.autoheuristic.autoheuristic_utils import (
AHContext,
AHMetadata,
Choice,
)
from torch._inductor.autoheuristic.learnedheuristic_interface import (
LearnedHeuristicDecision,
)
class MixedMMH100(LearnedHeuristicDecision):
def __init__(self) -> None:
self.choices: List[Choice] = []
self.fill_choices()
def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
return (
metadata.name == self.get_name()
and metadata.shared_memory == 232448
and str(metadata.device_capa) == "(9, 0)"
)
def get_confidence_threshold(self) -> float:
return 0.0
def get_choice(self, idx: int) -> Optional[str]:
if idx < len(self.choices):
return self.choices[idx]
return None
def fill_choices(self) -> None:
self.choices.append('extern_fallback_mixed_mm')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
def get_name(self) -> str:
return 'mixed_mm'
def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
if context.get_value('arith_intensity') <= 15.988086223602295:
if context.get_value('n') <= 25280.0:
if context.get_value('n') <= 1344.0:
if context.get_value('mat1_stride_0') <= 7808.0:
return [(0.581, 7), (0.419, 6)]
else:
if context.get_value('m*n') <= 7680.0:
return [(0.875, 0), (0.125, 6)]
else:
return [(0.833, 0), (0.167, 7)]
else:
if context.get_value('n') <= 8512.0:
if str(context.get_value('mat2_dtype')) != 'torch.int8':
return [(0.763, 6), (0.237, 7)]
else:
return [(0.725, 7), (0.275, 6)]
else:
if str(context.get_value('mat1_dtype')) != 'torch.bfloat16':
return [(0.736, 7), (0.197, 9), (0.048, 6), (0.014, 8), (0.005, 10)]
else:
return [(0.473, 7), (0.398, 6), (0.097, 9), (0.032, 10)]
else:
if context.get_value('n') <= 42254.0:
if context.get_value('n') <= 33856.0:
if context.get_value('k*n') <= 68157440.0:
return [(0.370, 4), (0.370, 5), (0.074, 7), (0.074, 8), (0.074, 11), (0.037, 6)]
else:
return [(0.916, 8), (0.036, 7), (0.036, 9), (0.012, 4)]
else:
return [(0.659, 5), (0.341, 6)]
else:
if context.get_value('k*n') <= 326052992.0:
if context.get_value('n') <= 55232.0:
return [(0.571, 6), (0.321, 7), (0.036, 4), (0.036, 8), (0.036, 9)]
else:
return [(0.506, 6), (0.325, 8), (0.104, 7), (0.039, 5), (0.026, 9)]
else:
if context.get_value('n') <= 57024.0:
return [(0.462, 9), (0.385, 7), (0.115, 6), (0.038, 8)]
else:
return [(0.598, 8), (0.223, 9), (0.107, 6), (0.071, 7)]
else:
if context.get_value('m*n') <= 543936.0:
if str(context.get_value('17LEQmLEQ32')) != 'True':
if context.get_value('m*n') <= 262272.0:
if context.get_value('n') <= 1592.5:
return [(0.860, 0), (0.140, 9)]
else:
return None
else:
if context.get_value('m*k') <= 1294336.0:
return [(0.833, 17), (0.150, 18), (0.017, 15)]
else:
return [(0.917, 17), (0.083, 8)]
else:
if context.get_value('n') <= 12416.0:
if context.get_value('m*n') <= 43008.0:
return None
else:
return [(0.853, 14), (0.147, 9)]
else:
return [(0.625, 12), (0.375, 14)]
else:
if context.get_value('m') <= 32.5:
if context.get_value('mat2_stride_1') <= 6656.0:
if context.get_value('n') <= 69184.0:
return [(0.611, 12), (0.361, 14), (0.028, 13)]
else:
return [(1.000, 12)]
else:
if context.get_value('mat2_stride_1') <= 20864.0:
return [(1.000, 12)]
else:
return [(0.958, 12), (0.042, 9)]
else:
if context.get_value('m*n') <= 1085440.0:
if context.get_value('n') <= 9152.0:
return [(1.000, 18)]
else:
return [(0.780, 18), (0.160, 16), (0.060, 20)]
else:
if context.get_value('m') <= 67.0:
return [(0.650, 16), (0.203, 19), (0.122, 18), (0.016, 20), (0.008, 1)]
else:
return [(0.561, 3), (0.185, 16), (0.096, 20), (0.083, 19), (0.076, 2)]

View File

@ -0,0 +1,5 @@
#!/bin/bash
data="mixedmm_h100_data.txt"
python train_decision_mixedmm.py ${data} --heuristic-name MixedMMH100

View File

@ -1,6 +1,12 @@
#!/bin/bash #!/bin/bash
a100_data='https://github.com/AlnisM/autoheuristic-datasets/raw/main/mixedmm_a100_data.zip' base_url='https://github.com/AlnisM/autoheuristic-datasets/raw/main/'
wget ${a100_data} a100_data='mixedmm_a100_data.zip'
unzip mixedmm_a100_data.zip h100_data='mixedmm_h100_data.zip'
rm mixedmm_a100_data.zip datasets=("${a100_data}" "${h100_data}")
for dataset in "${datasets[@]}"; do
url="${base_url}${dataset}"
wget ${url}
unzip ${dataset}
rm ${dataset}
done