AutoHeuristic: Collect data for mixed_mm (#131611)

This PR introduces a script that can be used to collect data for mixed_mm, in order to learn a heuristic with AutoHeuristic. It also includes the following changes:

- Move the pad_mm-related AutoHeuristic files into a subdirectory.
- Introduce an interface, benchmark_runner.py, that can be subclassed to write new benchmark scripts for collecting data with AutoHeuristic (see gen_data_pad_mm.py and gen_data_mixed_mm.py).

The idea behind the interface is to make it easier to collect data for new optimizations, and thus easier to learn new heuristics, as sketched below.
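
As a minimal sketch of the intended workflow (the `BenchmarkRunnerRelu` name, the "relu" heuristic name, and the input sizes are hypothetical, not part of this PR), a new data-collection script subclasses `BenchmarkRunner` and implements `create_input` and `run_benchmark`:

```python
# Hypothetical example of a new data-collection script built on the
# BenchmarkRunner interface introduced in this PR.
from typing import Any, Tuple

import torch

from benchmark_runner import BenchmarkRunner


class BenchmarkRunnerRelu(BenchmarkRunner):
    def __init__(self) -> None:
        # The name ties the collected data to an AutoHeuristic optimization.
        super().__init__("relu")

    def create_input(self) -> Tuple[Any, ...]:
        # Sample a random input size between 2^10 and 2^20 elements.
        return (self.get_random_between_pow2(10, 20),)

    def run_benchmark(self, numel: int) -> None:
        x = torch.randn(numel)
        cf = torch.compile(torch.relu, mode="max-autotune-no-cudagraphs")
        cf(x)
        torch.compiler.reset()


if __name__ == "__main__":
    BenchmarkRunnerRelu().run()
```

`run()` (defined in benchmark_runner.py below) parses the shared command-line arguments, configures AutoHeuristic for either data collection or heuristic use, and then calls `create_input`/`run_benchmark` in a loop.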

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131611
Approved by: https://github.com/eellison
ghstack dependencies: #131610
Alnis Murtovi
2024-07-31 10:43:12 -07:00
committed by PyTorch MergeBot
parent f8b6e91840
commit d3cefc9e3a
10 changed files with 438 additions and 216 deletions

View File: benchmark_runner.py

@@ -0,0 +1,87 @@
import argparse
import random
import time
from abc import abstractmethod
from typing import Any, Tuple
from tqdm import tqdm # type: ignore[import-untyped]
import torch
class BenchmarkRunner:
"""
BenchmarkRunner is a base class for all benchmark runners. It provides an interface to run benchmarks in order to
collect data with AutoHeuristic.
"""
def __init__(self, name: str) -> None:
self.name = name
self.parser = argparse.ArgumentParser()
self.add_base_arguments()
self.args = None
def add_base_arguments(self) -> None:
self.parser.add_argument(
"--device",
type=int,
default=None,
help="torch.cuda.set_device(device) will be used",
)
self.parser.add_argument(
"--use-heuristic",
action="store_true",
help="Use learned heuristic instead of collecting data.",
)
self.parser.add_argument(
"-o",
type=str,
default="ah_data.txt",
help="Path to file where AutoHeuristic will log results.",
)
self.parser.add_argument(
"--num-samples",
type=int,
default=1000,
help="Number of samples to collect.",
)
self.parser.add_argument(
"--num-reps",
type=int,
default=3,
help="Number of measurements to collect for each input.",
)
def run(self) -> None:
torch.set_default_device("cuda")
args = self.parser.parse_args()
if args.use_heuristic:
torch._inductor.config.autoheuristic_use = self.name
else:
torch._inductor.config.autoheuristic_collect = self.name
torch._inductor.config.autoheuristic_log_path = args.o
if args.device is not None:
torch.cuda.set_device(args.device)
random.seed(time.time())
self.main(args.num_samples, args.num_reps)
def get_random_between_pow2(self, min_power2: int, max_power2: int) -> int:
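        # Pick a power-of-two bucket [2^i, 2^(i+1)] uniformly at random, then
        # return a random integer strictly between those two powers of two.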
i = random.randint(min_power2, max_power2 - 1)
lower = 2**i + 1
upper = 2 ** (i + 1) - 1
assert lower <= upper, "lower must not be greater than upper"
return random.randint(lower, upper)
@abstractmethod
def run_benchmark(self, *args: Any) -> None:
...
@abstractmethod
def create_input(self) -> Tuple[Any, ...]:
...
def main(self, num_samples: int, num_reps: int) -> None:
for _ in tqdm(range(num_samples)):
input = self.create_input()
for _ in range(num_reps):
self.run_benchmark(*input)

View File: benchmark_utils.py

@@ -0,0 +1,44 @@
import random
from typing import Any, Tuple
import torch
def transpose_tensors(p_transpose_both: float = 0.05) -> Tuple[bool, bool]:
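    # With probability p_transpose_both, transpose both operands; otherwise pick
    # uniformly among transposing only the left, only the right, or neither.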
transpose_both = random.choices(
[True, False], [p_transpose_both, 1 - p_transpose_both]
)[0]
if transpose_both:
return (True, True)
transpose_left = (True, False)
transpose_right = (False, True)
no_transpose = (False, False)
return random.choices([transpose_left, transpose_right, no_transpose])[0]
def fits_in_memory(dtype: Any, m: int, k: int, n: int) -> Any:
threshold_memory = torch.cuda.get_device_properties(0).total_memory / 4
    # Divide by 4 because we otherwise sometimes run out of memory, presumably
    # because inductor creates copies of the tensors for benchmarking.
return dtype.itemsize * (m * k + k * n + m * n) < threshold_memory
def get_mm_tensors(
m: int,
k: int,
n: int,
transpose_left: bool,
transpose_right: bool,
dtype_left: Any,
dtype_right: Any,
) -> Tuple[Any, Any]:
if transpose_left:
a = torch.randn(k, m, dtype=dtype_left).t()
else:
a = torch.randn(m, k, dtype=dtype_left)
if transpose_right:
b = torch.randn(n, k, dtype=dtype_right).t()
else:
b = torch.randn(k, n, dtype=dtype_right)
return (a, b)

View File: gen_data_mixed_mm.py

@@ -0,0 +1,145 @@
# mypy: ignore-errors
import os
import random
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from typing import Any, Tuple
from benchmark_runner import BenchmarkRunner # type: ignore[import-not-found]
from benchmark_utils import ( # type: ignore[import-not-found]
fits_in_memory,
get_mm_tensors,
)
import torch
from torch._inductor.utils import fresh_inductor_cache
class BenchmarkRunnerMixedMM(BenchmarkRunner): # type: ignore[misc, no-any-unimported]
"""
    BenchmarkRunner for mixed mm. Used to collect training data with AutoHeuristic in order to learn a heuristic.
Currently, we are generating inputs with the following restrictions:
- m <= 128, and n and k >= 1024 (for these inputs one of the triton kernels wins in most cases)
- k % 256 == 0 (if k is not a multiple of the block size, this can have a huge negative impact on performance)
- mat1 not transposed
- mat2 transposed
This allows us to learn a heuristic that works well e.g. for gpt-fast.
"""
def __init__(self) -> None:
super().__init__("mixed_mm")
def create_input(self) -> Tuple[Any, ...]:
dtype1, dtype2 = self.get_dtypes()
m, k, n = self.get_m_k_n(dtype1)
transpose_left, transpose_right = False, True
return (m, k, n, transpose_left, transpose_right, dtype1, dtype2)
def run_benchmark(
self,
m: int,
k: int,
n: int,
transpose_left: bool,
transpose_right: bool,
dtype_left: Any,
dtype_right: Any,
) -> Any:
a, b = get_mm_tensors(
m,
k,
n,
transpose_left,
transpose_right,
dtype_left=dtype_left,
dtype_right=torch.float32,
)
b = b.to(dtype=dtype_right)
with fresh_inductor_cache():
def mixed_mm(A, B):
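                # torch.mm with an in-graph upcast of B is the mixed-dtype matmul
                # pattern that inductor's mixed_mm lowering matches and autotunes.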
return torch.mm(A, B.to(A.dtype))
cf = torch.compile(mixed_mm, mode="max-autotune-no-cudagraphs")
cf(a, b)
torch.compiler.reset()
def random_multiple_of_128(self, min_num=7, max_num=17):
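        # Pick a power-of-two bucket [2^ran_pow2, 2^(ran_pow2 + 1)] and return a
        # random multiple of 128 inside it, so sizes are roughly log-uniformly
        # distributed across the allowed range.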
ran_pow2 = random.randint(min_num, max_num - 1)
start = (2**ran_pow2) // 128
end = (2 ** (ran_pow2 + 1)) // 128
random_multiple = random.randint(start, end)
return random_multiple * 128
def get_random_pow2(self, min_power2: int, max_power2: int):
return 2 ** random.randint(min_power2, max_power2)
def get_distr_type(self) -> str:
# 85%: choose a random multiple of 128 between 2^10 and 2^17
# 10%: choose a random power of 2 between 2^10 and 2^17
# 4%: choose a random number between 1024 and 131072
# 1%: choose a random number between 2^i and 2^(i+1) with i in [10, 16]
return random.choices(
["mult_128", "pow2", "uniform", "uniform-between-pow2"],
[0.85, 0.1, 0.04, 0.01],
)[0]
def get_random_dim(self):
distr_type = self.get_distr_type()
if distr_type == "mult_128":
return self.random_multiple_of_128(min_num=10, max_num=17)
if distr_type == "pow2":
return self.get_random_pow2(min_power2=10, max_power2=17)
elif distr_type == "uniform-between-pow2":
return self.get_random_between_pow2(min_power2=10, max_power2=17)
elif distr_type == "uniform":
return random.randint(1024, 131072)
print(f"random_type {distr_type} not supported")
sys.exit(1)
def get_random_num_small(self) -> int:
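        # m is kept small (at most 2^7 = 128) to match the restriction in the
        # class docstring: return an exact power of two with probability 0.75,
        # otherwise a value strictly between two powers of two.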
pow2 = random.choices([True, False], [0.75, 0.25])[0]
if pow2:
return 2 ** random.randint(1, 7)
else:
return self.get_random_between_pow2(1, 7)
def get_m_k_n(self, dtype: Any) -> Tuple[int, int, int]:
numel_max = 2**31
# repeat until tensors fit in memory
while True:
m = self.get_random_num_small()
k = self.get_random_dim()
n = self.get_random_dim()
if k % 256 != 0:
continue
assert k >= 1024 and n >= 1024, "k and n must be at least 1024"
if m * k >= numel_max or m * n >= numel_max or k * n >= numel_max:
# autotuning will not happen for tensors that are this large
continue
if fits_in_memory(dtype, m, k, n):
return (m, k, n)
def get_dtypes(self) -> Any:
while True:
dtype_floats = [torch.float16, torch.bfloat16]
dtype_ints = [torch.int8, torch.uint8]
mat1_dtype = random.choices(dtype_floats)[0]
mat2_dtype = random.choices(dtype_ints)[0]
if mat1_dtype == torch.bfloat16 and mat2_dtype == torch.uint8:
# this combination seems to cause issues with mixed_mm
continue
return (mat1_dtype, mat2_dtype)
if __name__ == "__main__":
runner = BenchmarkRunnerMixedMM()
runner.run()

View File: gen_data_pad_mm.py

@@ -0,0 +1,157 @@
import os
import random
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from typing import Any, Tuple
from benchmark_runner import BenchmarkRunner # type: ignore[import-not-found]
from benchmark_utils import ( # type: ignore[import-not-found]
fits_in_memory,
get_mm_tensors,
transpose_tensors,
)
import torch
from torch._inductor.fx_passes.pad_mm import ( # type: ignore[import-not-found]
get_alignment_size_dtype,
)
from torch._inductor.utils import fresh_inductor_cache
class BenchmarkRunnerPadMM(BenchmarkRunner): # type: ignore[misc, no-any-unimported]
"""
    BenchmarkRunner for pad_mm. Used to collect training data with AutoHeuristic in order to learn a heuristic.
"""
def __init__(self) -> None:
super().__init__("pad_mm")
def create_input(self) -> Tuple[Any, ...]:
dtype = self.get_dtype()
self.set_precision(dtype)
m, k, n = self.get_m_k_n(dtype)
(transpose_left, transpose_right) = transpose_tensors()
prepadded_left = self.prepadded()
prepadded_right = self.prepadded()
return (
m,
k,
n,
transpose_left,
transpose_right,
dtype,
prepadded_left,
prepadded_right,
)
def run_benchmark(
self,
m: int,
k: int,
n: int,
transpose_left: bool,
transpose_right: bool,
dtype: Any,
prepadded_left: bool,
prepadded_right: bool,
) -> None:
a, b = get_mm_tensors(
m,
k,
n,
transpose_left,
transpose_right,
dtype_left=dtype,
dtype_right=dtype,
)
print("Benchmarking the following input:")
print(f"m={m} k={k} n={n} dtype={dtype}")
print(f"transpose_left={transpose_left} transpose_right={transpose_right}")
print(f"prepadded_left={prepadded_left} prepadded_right={prepadded_right}")
with fresh_inductor_cache():
def mm(a: Any, b: Any) -> Any:
return torch.mm(a, b)
def mm_mat1_prepadded(a: Any, b: Any) -> Any:
return torch.mm(a + 1, b)
def mm_mat2_prepadded(a: Any, b: Any) -> Any:
return torch.mm(a, b + 1)
def mm_mat1_mat2_prepadded(a: Any, b: Any) -> Any:
return torch.mm(a + 1, b + 1)
if prepadded_left and prepadded_right:
cf = torch.compile(mm_mat1_mat2_prepadded)
elif prepadded_left:
cf = torch.compile(mm_mat1_prepadded)
elif prepadded_right:
cf = torch.compile(mm_mat2_prepadded)
else:
cf = torch.compile(mm)
cf(a, b)
torch.compiler.reset()
def get_random_dim(
self, min_power2: int = 1, max_power2: int = 16, p_unaligned: float = 0.25
) -> int:
aligned = random.choices([True, False], [1 - p_unaligned, p_unaligned])[0]
if aligned:
return 2 ** random.randint(min_power2, max_power2) # type: ignore[no-any-return]
else:
# choose a random number between 2^i and 2^(i+1)
return self.get_random_between_pow2(min_power2, max_power2) # type: ignore[no-any-return]
def is_aligned(self, dim: int, align_size: int) -> bool:
return dim % align_size == 0
def get_m_k_n(self, dtype: Any) -> Tuple[int, int, int]:
uniform = random.choices([True, False])[0]
align_size = get_alignment_size_dtype(dtype)
# repeat until tensors fit in memory
while True:
if uniform:
m = random.randint(1, 65536)
k = random.randint(1, 65536)
n = random.randint(1, 65536)
else:
m = self.get_random_dim()
k = self.get_random_dim()
n = self.get_random_dim()
if all(self.is_aligned(dim, align_size) for dim in [m, k, n]):
# skip if already aligned
continue
if fits_in_memory(dtype, m, k, n):
return (m, k, n)
def prepadded(self, p_prepadded: float = 0.2) -> bool:
# p_prepadded: probability that a tensor is "prepadded", i.e. pad_mm excludes time it takes to pad from benchmarking
return random.choices([True, False], [p_prepadded, 1 - p_prepadded])[0]
def get_dtype(self) -> Any:
dtype_choices = [torch.float16, torch.bfloat16, torch.float32]
return random.choices(dtype_choices)[0]
def set_precision(self, dtype: Any, p_float32_prec_highest: float = 0.8) -> None:
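        # For float32, sample both matmul precision modes: "high" allows TF32,
        # while "highest" keeps full fp32 precision.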
if dtype == torch.float32:
precisions = ["high", "highest"]
weights = [1 - p_float32_prec_highest, p_float32_prec_highest]
precision = random.choices(precisions, weights)[0]
else:
precision = "high"
torch.set_float32_matmul_precision(precision)
if __name__ == "__main__":
runner = BenchmarkRunnerPadMM()
runner.run()

View File

@@ -1,4 +1,9 @@
# mypy: ignore-errors
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from train import AHTrain

View File

@@ -1,216 +0,0 @@
import argparse
import random
import time
from typing import Any, Tuple
from tqdm import tqdm # type: ignore[import-untyped]
import torch
torch.set_default_device("cuda")
from torch._inductor.fx_passes.pad_mm import get_alignment_size_dtype
from torch._inductor.utils import fresh_inductor_cache
# A100: 81920MiB
# without a threshold we sometimes run out of memory
threshold_memory = 85899345920 / 4
# probability that a dimension is unaligned
p_unaligned = 0.25
# probability that a tensor is "prepadded", i.e. pad_mm excludes time it takes to pad from benchmarking
p_prepadded = 0.2
# probability that we pick from uniform distribution
p_uniform = 0.5
p_float32_prec_highest = 0.8
def benchmark(
m: int,
k: int,
n: int,
transpose_left: bool,
transpose_right: bool,
dtype: Any,
prepadded_left: bool,
prepadded_right: bool,
) -> None:
if transpose_left:
a = torch.randn(k, m, dtype=dtype).t()
else:
a = torch.randn(m, k, dtype=dtype)
if transpose_right:
b = torch.randn(n, k, dtype=dtype).t()
else:
b = torch.randn(k, n, dtype=dtype)
with fresh_inductor_cache():
def mm(a: Any, b: Any) -> Any:
return torch.mm(a, b)
def mm_mat1_prepadded(a: Any, b: Any) -> Any:
return torch.mm(a + 1, b)
def mm_mat2_prepadded(a: Any, b: Any) -> Any:
return torch.mm(a, b + 1)
def mm_mat1_mat2_prepadded(a: Any, b: Any) -> Any:
return torch.mm(a + 1, b + 1)
if prepadded_left and prepadded_right:
cf = torch.compile(mm_mat1_mat2_prepadded)
elif prepadded_left:
cf = torch.compile(mm_mat1_prepadded)
elif prepadded_right:
cf = torch.compile(mm_mat2_prepadded)
else:
cf = torch.compile(mm)
cf(a, b)
torch.compiler.reset()
def fits_in_memory(dtype: Any, m: int, k: int, n: int) -> Any:
return dtype.itemsize * (m * k + k * n + m * n) < threshold_memory
def get_random_dim(min_power2: int = 1, max_power2: int = 16) -> int:
aligned = random.choices([True, False], [1 - p_unaligned, p_unaligned])[0]
if aligned:
return 2 ** random.randint(min_power2, max_power2) # type: ignore[no-any-return]
else:
# choose a random number between 2^i and 2^(i+1)
i = random.randint(min_power2, max_power2 - 1)
lower = 2**i + 1
upper = 2 ** (i + 1) - 1
return random.randint(lower, upper)
def is_aligned(dim: int, align_size: int) -> bool:
return dim % align_size == 0
def get_m_k_n(dtype: Any) -> Tuple[int, int, int]:
uniform = random.choices([True, False], [0.5, 0.5])[0]
align_size = get_alignment_size_dtype(dtype)
# repeat until tensors fit in memory
while True:
if uniform:
m = random.randint(1, 65536)
k = random.randint(1, 65536)
n = random.randint(1, 65536)
else:
m = get_random_dim()
k = get_random_dim()
n = get_random_dim()
if all(is_aligned(dim, align_size) for dim in [m, k, n]):
# skip if already aligned
continue
if fits_in_memory(dtype, m, k, n):
return (m, k, n)
def transpose_tensors() -> Tuple[bool, bool]:
p_transpose_both = 0.05
transpose_both = random.choices(
[True, False], [p_transpose_both, 1 - p_transpose_both]
)[0]
if transpose_both:
return (True, True)
transpose_left = (True, False)
transpose_right = (False, True)
no_transpose = (False, False)
return random.choices([transpose_left, transpose_right, no_transpose])[0]
def prepadded() -> bool:
return random.choices([True, False], [p_prepadded, 1 - p_prepadded])[0]
def get_dtype() -> Any:
dtype_choices = [torch.float16, torch.bfloat16, torch.float32]
return random.choices(dtype_choices)[0]
def set_precision(dtype: Any) -> None:
if dtype == torch.float32:
precisions = ["high", "highest"]
weights = [1 - p_float32_prec_highest, p_float32_prec_highest]
precision = random.choices(precisions, weights)[0]
else:
precision = "high"
torch.set_float32_matmul_precision(precision)
def main(num_samples: int) -> None:
for i in tqdm(range(num_samples)):
dtype = get_dtype()
set_precision(dtype)
m, k, n = get_m_k_n(dtype)
(transpose_left, transpose_right) = transpose_tensors()
prepadded_left = prepadded()
prepadded_right = prepadded()
print("Benchmarking the following input:")
print(f"m={m} k={k} n={n} dtype={dtype}")
print(f"transpose_left={transpose_left} transpose_right={transpose_right}")
print(f"prepadded_left={prepadded_left} prepadded_right={prepadded_right}")
for i in range(3):
benchmark(
m,
k,
n,
transpose_left,
transpose_right,
dtype,
prepadded_left,
prepadded_right,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--device",
type=int,
default=None,
help="torch.cuda.set_device(device) will be used",
)
parser.add_argument(
"--use-heuristic",
action="store_true",
help="Use learned heuristic instead of collecting data.",
)
parser.add_argument(
"-o",
type=str,
default="a100_data.txt",
help="Path to file where AutoHeuristic will log results.",
)
parser.add_argument(
"--num-samples",
type=int,
default=1000,
help="Number of samples to collect.",
)
args = parser.parse_args()
if args.use_heuristic:
torch._inductor.config.autoheuristic_use = "pad_mm"
else:
torch._inductor.config.autoheuristic_collect = "pad_mm"
torch._inductor.config.autoheuristic_log_path = args.o
if args.device is not None:
torch.cuda.set_device(args.device)
random.seed(time.time())
main(args.num_samples)