Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
AutoHeuristic: Collect data for mixed_mm (#131611)

This PR introduces a script that can be used to collect data for mixed_mm in order to learn a heuristic with AutoHeuristic. It also includes the following changes:
- Move the pad_mm-related AutoHeuristic files into a subdirectory.
- Introduce an interface, benchmark_runner.py, that can be subclassed to add new benchmark scripts for collecting data with AutoHeuristic (see gen_data_pad_mm.py and gen_data_mixed_mm.py; a sketch of such a subclass appears after the benchmark_runner.py listing below). The idea behind the interface is that it hopefully makes it easier to collect data for new optimizations, and thus easier to learn new heuristics.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131611
Approved by: https://github.com/eellison
ghstack dependencies: #131610
This commit is contained in:
Commit: d3cefc9e3a
Parent: f8b6e91840
Committed by: PyTorch MergeBot
torchgen/_autoheuristic/benchmark_runner.py (new file, 87 lines)
@@ -0,0 +1,87 @@
import argparse
import random
import time
from abc import abstractmethod
from typing import Any, Tuple

from tqdm import tqdm  # type: ignore[import-untyped]

import torch


class BenchmarkRunner:
    """
    BenchmarkRunner is a base class for all benchmark runners. It provides an interface to run benchmarks in order to
    collect data with AutoHeuristic.
    """

    def __init__(self, name: str) -> None:
        self.name = name
        self.parser = argparse.ArgumentParser()
        self.add_base_arguments()
        self.args = None

    def add_base_arguments(self) -> None:
        self.parser.add_argument(
            "--device",
            type=int,
            default=None,
            help="torch.cuda.set_device(device) will be used",
        )
        self.parser.add_argument(
            "--use-heuristic",
            action="store_true",
            help="Use learned heuristic instead of collecting data.",
        )
        self.parser.add_argument(
            "-o",
            type=str,
            default="ah_data.txt",
            help="Path to file where AutoHeuristic will log results.",
        )
        self.parser.add_argument(
            "--num-samples",
            type=int,
            default=1000,
            help="Number of samples to collect.",
        )
        self.parser.add_argument(
            "--num-reps",
            type=int,
            default=3,
            help="Number of measurements to collect for each input.",
        )

    def run(self) -> None:
        torch.set_default_device("cuda")
        args = self.parser.parse_args()
        # configure AutoHeuristic to either replay a learned heuristic or
        # collect data under the given heuristic name
        if args.use_heuristic:
            torch._inductor.config.autoheuristic_use = self.name
        else:
            torch._inductor.config.autoheuristic_collect = self.name
        torch._inductor.config.autoheuristic_log_path = args.o
        if args.device is not None:
            torch.cuda.set_device(args.device)
        random.seed(time.time())
        self.main(args.num_samples, args.num_reps)

    def get_random_between_pow2(self, min_power2: int, max_power2: int) -> int:
        # pick a random number that lies strictly between two consecutive
        # powers of two, i.e. in (2^i, 2^(i+1)) for a random i
        i = random.randint(min_power2, max_power2 - 1)
        lower = 2**i + 1
        upper = 2 ** (i + 1) - 1
        assert lower <= upper, "lower must not be greater than upper"
        return random.randint(lower, upper)

    @abstractmethod
    def run_benchmark(self, *args: Any) -> None:
        ...

    @abstractmethod
    def create_input(self) -> Tuple[Any, ...]:
        ...

    def main(self, num_samples: int, num_reps: int) -> None:
        for _ in tqdm(range(num_samples)):
            input = self.create_input()
            for _ in range(num_reps):
                self.run_benchmark(*input)
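To illustrate how the interface is meant to be extended, here is a minimal, hypothetical subclass that would collect data for a plain matmul. Only the BenchmarkRunner interface (create_input, run_benchmark, run) comes from the listing above; the class name, the "mm" heuristic name, the dtypes, and the size ranges are made up for illustration, and such a heuristic name would also need to be supported on the inductor side.

# Hypothetical sketch of a new benchmark runner; everything specific here
# (the "mm" name, dtypes, size ranges) is illustrative only.
from typing import Any, Tuple

import torch
from torch._inductor.utils import fresh_inductor_cache

from benchmark_runner import BenchmarkRunner


class BenchmarkRunnerMM(BenchmarkRunner):
    def __init__(self) -> None:
        super().__init__("mm")  # heuristic name handed to AutoHeuristic

    def create_input(self) -> Tuple[Any, ...]:
        # sample a random problem size; the ranges are arbitrary for this sketch
        m = self.get_random_between_pow2(5, 14)
        k = self.get_random_between_pow2(5, 14)
        n = self.get_random_between_pow2(5, 14)
        return (m, k, n)

    def run_benchmark(self, m: int, k: int, n: int) -> None:
        a = torch.randn(m, k, dtype=torch.float16)
        b = torch.randn(k, n, dtype=torch.float16)
        with fresh_inductor_cache():
            # compile so that autotuning runs and AutoHeuristic can log its data
            cf = torch.compile(torch.mm, mode="max-autotune-no-cudagraphs")
            cf(a, b)
            torch.compiler.reset()


if __name__ == "__main__":
    BenchmarkRunnerMM().run()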
torchgen/_autoheuristic/benchmark_utils.py (new file, 44 lines)
@@ -0,0 +1,44 @@
import random
from typing import Any, Tuple

import torch


def transpose_tensors(p_transpose_both: float = 0.05) -> Tuple[bool, bool]:
    # with probability p_transpose_both, transpose both matrices; otherwise
    # pick uniformly among transposing only the left, only the right, or neither
    transpose_both = random.choices(
        [True, False], [p_transpose_both, 1 - p_transpose_both]
    )[0]
    if transpose_both:
        return (True, True)
    transpose_left = (True, False)
    transpose_right = (False, True)
    no_transpose = (False, False)
    return random.choices([transpose_left, transpose_right, no_transpose])[0]


def fits_in_memory(dtype: Any, m: int, k: int, n: int) -> Any:
    threshold_memory = torch.cuda.get_device_properties(0).total_memory / 4
    # dividing by 4 because we otherwise sometimes run out of memory, presumably
    # because inductor creates copies of tensors for benchmarking
    return dtype.itemsize * (m * k + k * n + m * n) < threshold_memory


def get_mm_tensors(
    m: int,
    k: int,
    n: int,
    transpose_left: bool,
    transpose_right: bool,
    dtype_left: Any,
    dtype_right: Any,
) -> Tuple[Any, Any]:
    if transpose_left:
        a = torch.randn(k, m, dtype=dtype_left).t()
    else:
        a = torch.randn(m, k, dtype=dtype_left)

    if transpose_right:
        b = torch.randn(n, k, dtype=dtype_right).t()
    else:
        b = torch.randn(k, n, dtype=dtype_right)
    return (a, b)
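A quick worked example of the fits_in_memory check. The concrete numbers below are hypothetical; on an 80 GiB device the threshold (total_memory / 4) would be roughly 20 GiB:

# Worked example for fits_in_memory with made-up sizes: two fp16 inputs plus the output.
m, k, n = 128, 4096, 8192
itemsize = 2  # torch.float16
bytes_needed = itemsize * (m * k + k * n + m * n)  # A is m*k, B is k*n, output is m*n
print(bytes_needed)  # 70254592 bytes = 67 MiB, far below a ~20 GiB threshold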
torchgen/_autoheuristic/mixed_mm/gen_data_mixed_mm.py (new file, 145 lines)
@@ -0,0 +1,145 @@
# mypy: ignore-errors
import os
import random
import sys


sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from typing import Any, Tuple

from benchmark_runner import BenchmarkRunner  # type: ignore[import-not-found]
from benchmark_utils import (  # type: ignore[import-not-found]
    fits_in_memory,
    get_mm_tensors,
)

import torch
from torch._inductor.utils import fresh_inductor_cache


class BenchmarkRunnerMixedMM(BenchmarkRunner):  # type: ignore[misc, no-any-unimported]
    """
    BenchmarkRunner for mixed mm. Used to collect training data with AutoHeuristic
    in order to learn a heuristic.
    Currently, we are generating inputs with the following restrictions:
    - m <= 128, and n and k >= 1024 (for these inputs, one of the triton kernels wins in most cases)
    - k % 256 == 0 (if k is not a multiple of the block size, this can have a huge negative impact on performance)
    - mat1 not transposed
    - mat2 transposed
    This allows us to learn a heuristic that works well e.g. for gpt-fast.
    """

    def __init__(self) -> None:
        super().__init__("mixed_mm")

    def create_input(self) -> Tuple[Any, ...]:
        dtype1, dtype2 = self.get_dtypes()
        m, k, n = self.get_m_k_n(dtype1)
        transpose_left, transpose_right = False, True
        return (m, k, n, transpose_left, transpose_right, dtype1, dtype2)

    def run_benchmark(
        self,
        m: int,
        k: int,
        n: int,
        transpose_left: bool,
        transpose_right: bool,
        dtype_left: Any,
        dtype_right: Any,
    ) -> Any:
        # create mat2 as float32 first (randn does not support integer dtypes)
        # and then cast it to the low-precision dtype
        a, b = get_mm_tensors(
            m,
            k,
            n,
            transpose_left,
            transpose_right,
            dtype_left=dtype_left,
            dtype_right=torch.float32,
        )
        b = b.to(dtype=dtype_right)

        with fresh_inductor_cache():

            def mixed_mm(A, B):
                return torch.mm(A, B.to(A.dtype))

            cf = torch.compile(mixed_mm, mode="max-autotune-no-cudagraphs")
            cf(a, b)
            torch.compiler.reset()

    def random_multiple_of_128(self, min_num=7, max_num=17):
        ran_pow2 = random.randint(min_num, max_num - 1)
        start = (2**ran_pow2) // 128
        end = (2 ** (ran_pow2 + 1)) // 128
        random_multiple = random.randint(start, end)
        return random_multiple * 128

    def get_random_pow2(self, min_power2: int, max_power2: int):
        return 2 ** random.randint(min_power2, max_power2)

    def get_distr_type(self) -> str:
        # 85%: choose a random multiple of 128 between 2^10 and 2^17
        # 10%: choose a random power of 2 between 2^10 and 2^17 favoring larger values
        # 4%: choose a random number between 1024 and 131072
        # 1%: choose a random number between 2^i and 2^(i+1) with i in [10, 16]
        return random.choices(
            ["mult_128", "pow2", "uniform", "uniform-between-pow2"],
            [0.85, 0.1, 0.04, 0.01],
        )[0]

    def get_random_dim(self):
        distr_type = self.get_distr_type()
        if distr_type == "mult_128":
            return self.random_multiple_of_128(min_num=10, max_num=17)
        if distr_type == "pow2":
            return self.get_random_pow2(min_power2=10, max_power2=17)
        elif distr_type == "uniform-between-pow2":
            return self.get_random_between_pow2(min_power2=10, max_power2=17)
        elif distr_type == "uniform":
            return random.randint(1024, 131072)
        print(f"random_type {distr_type} not supported")
        sys.exit(1)

    def get_random_num_small(self) -> int:
        pow2 = random.choices([True, False], [0.75, 0.25])[0]
        if pow2:
            return 2 ** random.randint(1, 7)
        else:
            return self.get_random_between_pow2(1, 7)

    def get_m_k_n(self, dtype: Any) -> Tuple[int, int, int]:
        numel_max = 2**31

        # repeat until tensors fit in memory
        while True:
            m = self.get_random_num_small()
            k = self.get_random_dim()
            n = self.get_random_dim()
            if k % 256 != 0:
                continue

            assert k >= 1024 and n >= 1024, "k and n must be at least 1024"

            if m * k >= numel_max or m * n >= numel_max or k * n >= numel_max:
                # autotuning will not happen for tensors that are this large
                continue

            if fits_in_memory(dtype, m, k, n):
                return (m, k, n)

    def get_dtypes(self) -> Any:
        while True:
            dtype_floats = [torch.float16, torch.bfloat16]
            dtype_ints = [torch.int8, torch.uint8]
            mat1_dtype = random.choices(dtype_floats)[0]
            mat2_dtype = random.choices(dtype_ints)[0]
            if mat1_dtype == torch.bfloat16 and mat2_dtype == torch.uint8:
                # this combination seems to cause issues with mixed_mm
                continue
            return (mat1_dtype, mat2_dtype)


if __name__ == "__main__":
    runner = BenchmarkRunnerMixedMM()
    runner.run()
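The script is driven from the command line, e.g. python gen_data_mixed_mm.py --num-samples 1000 -o data.txt (flags as defined in benchmark_runner.py). Per run() above, the two modes reduce to the following inductor config settings; this restatement is a sketch, and the log path is made up:

import torch

# collection mode (default): AutoHeuristic logs results to the file given via -o
torch._inductor.config.autoheuristic_collect = "mixed_mm"
torch._inductor.config.autoheuristic_log_path = "data.txt"  # hypothetical path

# replay mode (--use-heuristic): use the learned heuristic instead of collecting
# torch._inductor.config.autoheuristic_use = "mixed_mm"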
torchgen/_autoheuristic/pad_mm/gen_data_pad_mm.py (new file, 157 lines)
@@ -0,0 +1,157 @@
import os
import random
import sys


sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from typing import Any, Tuple

from benchmark_runner import BenchmarkRunner  # type: ignore[import-not-found]
from benchmark_utils import (  # type: ignore[import-not-found]
    fits_in_memory,
    get_mm_tensors,
    transpose_tensors,
)

import torch
from torch._inductor.fx_passes.pad_mm import (  # type: ignore[import-not-found]
    get_alignment_size_dtype,
)
from torch._inductor.utils import fresh_inductor_cache


class BenchmarkRunnerPadMM(BenchmarkRunner):  # type: ignore[misc, no-any-unimported]
    """
    BenchmarkRunner for pad_mm. Used to collect training data with AutoHeuristic
    in order to learn a heuristic.
    """

    def __init__(self) -> None:
        super().__init__("pad_mm")

    def create_input(self) -> Tuple[Any, ...]:
        dtype = self.get_dtype()
        self.set_precision(dtype)
        m, k, n = self.get_m_k_n(dtype)

        (transpose_left, transpose_right) = transpose_tensors()
        prepadded_left = self.prepadded()
        prepadded_right = self.prepadded()
        return (
            m,
            k,
            n,
            transpose_left,
            transpose_right,
            dtype,
            prepadded_left,
            prepadded_right,
        )

    def run_benchmark(
        self,
        m: int,
        k: int,
        n: int,
        transpose_left: bool,
        transpose_right: bool,
        dtype: Any,
        prepadded_left: bool,
        prepadded_right: bool,
    ) -> None:
        a, b = get_mm_tensors(
            m,
            k,
            n,
            transpose_left,
            transpose_right,
            dtype_left=dtype,
            dtype_right=dtype,
        )

        print("Benchmarking the following input:")
        print(f"m={m} k={k} n={n} dtype={dtype}")
        print(f"transpose_left={transpose_left} transpose_right={transpose_right}")
        print(f"prepadded_left={prepadded_left} prepadded_right={prepadded_right}")

        with fresh_inductor_cache():
            # the prepadded variants apply a pointwise op to an input first, so
            # that pad_mm excludes the time it takes to pad from benchmarking

            def mm(a: Any, b: Any) -> Any:
                return torch.mm(a, b)

            def mm_mat1_prepadded(a: Any, b: Any) -> Any:
                return torch.mm(a + 1, b)

            def mm_mat2_prepadded(a: Any, b: Any) -> Any:
                return torch.mm(a, b + 1)

            def mm_mat1_mat2_prepadded(a: Any, b: Any) -> Any:
                return torch.mm(a + 1, b + 1)

            if prepadded_left and prepadded_right:
                cf = torch.compile(mm_mat1_mat2_prepadded)
            elif prepadded_left:
                cf = torch.compile(mm_mat1_prepadded)
            elif prepadded_right:
                cf = torch.compile(mm_mat2_prepadded)
            else:
                cf = torch.compile(mm)
            cf(a, b)
            torch.compiler.reset()

    def get_random_dim(
        self, min_power2: int = 1, max_power2: int = 16, p_unaligned: float = 0.25
    ) -> int:
        aligned = random.choices([True, False], [1 - p_unaligned, p_unaligned])[0]
        if aligned:
            return 2 ** random.randint(min_power2, max_power2)  # type: ignore[no-any-return]
        else:
            # choose a random number between 2^i and 2^(i+1)
            return self.get_random_between_pow2(min_power2, max_power2)  # type: ignore[no-any-return]

    def is_aligned(self, dim: int, align_size: int) -> bool:
        return dim % align_size == 0

    def get_m_k_n(self, dtype: Any) -> Tuple[int, int, int]:
        uniform = random.choices([True, False])[0]
        align_size = get_alignment_size_dtype(dtype)

        # repeat until tensors fit in memory
        while True:
            if uniform:
                m = random.randint(1, 65536)
                k = random.randint(1, 65536)
                n = random.randint(1, 65536)
            else:
                m = self.get_random_dim()
                k = self.get_random_dim()
                n = self.get_random_dim()

            if all(self.is_aligned(dim, align_size) for dim in [m, k, n]):
                # skip if already aligned
                continue

            if fits_in_memory(dtype, m, k, n):
                return (m, k, n)

    def prepadded(self, p_prepadded: float = 0.2) -> bool:
        # p_prepadded: probability that a tensor is "prepadded", i.e. pad_mm
        # excludes time it takes to pad from benchmarking
        return random.choices([True, False], [p_prepadded, 1 - p_prepadded])[0]

    def get_dtype(self) -> Any:
        dtype_choices = [torch.float16, torch.bfloat16, torch.float32]
        return random.choices(dtype_choices)[0]

    def set_precision(self, dtype: Any, p_float32_prec_highest: float = 0.8) -> None:
        if dtype == torch.float32:
            precisions = ["high", "highest"]
            weights = [1 - p_float32_prec_highest, p_float32_prec_highest]
            precision = random.choices(precisions, weights)[0]
        else:
            precision = "high"
        torch.set_float32_matmul_precision(precision)


if __name__ == "__main__":
    runner = BenchmarkRunnerPadMM()
    runner.run()
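To make the alignment filter in get_m_k_n concrete: get_alignment_size_dtype comes from pad_mm and its exact return values are not shown in this diff, so the align_size below is an assumption for illustration. A candidate where every dimension is already a multiple of align_size is skipped (padding would never apply), while a candidate with one unaligned dimension is kept:

# Hypothetical walk-through of the alignment filter; align_size = 8 is assumed.
align_size = 8

def is_aligned(dim: int, align_size: int) -> bool:
    return dim % align_size == 0

print(all(is_aligned(d, align_size) for d in (128, 4096, 64)))  # True -> skipped
print(all(is_aligned(d, align_size) for d in (130, 4096, 64)))  # False -> kept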
@@ -1,4 +1,9 @@
# mypy: ignore-errors
import os
import sys


sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from train import AHTrain
(deleted file, 216 lines)
@@ -1,216 +0,0 @@
import argparse
import random
import time
from typing import Any, Tuple

from tqdm import tqdm  # type: ignore[import-untyped]

import torch


torch.set_default_device("cuda")

from torch._inductor.fx_passes.pad_mm import get_alignment_size_dtype
from torch._inductor.utils import fresh_inductor_cache


# A100: 81920MiB
# without a threshold we sometimes run out of memory
threshold_memory = 85899345920 / 4

# probability that a dimension is unaligned
p_unaligned = 0.25

# probability that a tensor is "prepadded", i.e. pad_mm excludes time it takes to pad from benchmarking
p_prepadded = 0.2

# probability that we pick from uniform distribution
p_uniform = 0.5

p_float32_prec_highest = 0.8


def benchmark(
    m: int,
    k: int,
    n: int,
    transpose_left: bool,
    transpose_right: bool,
    dtype: Any,
    prepadded_left: bool,
    prepadded_right: bool,
) -> None:
    if transpose_left:
        a = torch.randn(k, m, dtype=dtype).t()
    else:
        a = torch.randn(m, k, dtype=dtype)
    if transpose_right:
        b = torch.randn(n, k, dtype=dtype).t()
    else:
        b = torch.randn(k, n, dtype=dtype)

    with fresh_inductor_cache():

        def mm(a: Any, b: Any) -> Any:
            return torch.mm(a, b)

        def mm_mat1_prepadded(a: Any, b: Any) -> Any:
            return torch.mm(a + 1, b)

        def mm_mat2_prepadded(a: Any, b: Any) -> Any:
            return torch.mm(a, b + 1)

        def mm_mat1_mat2_prepadded(a: Any, b: Any) -> Any:
            return torch.mm(a + 1, b + 1)

        if prepadded_left and prepadded_right:
            cf = torch.compile(mm_mat1_mat2_prepadded)
        elif prepadded_left:
            cf = torch.compile(mm_mat1_prepadded)
        elif prepadded_right:
            cf = torch.compile(mm_mat2_prepadded)
        else:
            cf = torch.compile(mm)
        cf(a, b)
        torch.compiler.reset()


def fits_in_memory(dtype: Any, m: int, k: int, n: int) -> Any:
    return dtype.itemsize * (m * k + k * n + m * n) < threshold_memory


def get_random_dim(min_power2: int = 1, max_power2: int = 16) -> int:
    aligned = random.choices([True, False], [1 - p_unaligned, p_unaligned])[0]
    if aligned:
        return 2 ** random.randint(min_power2, max_power2)  # type: ignore[no-any-return]
    else:
        # choose a random number between 2^i and 2^(i+1)
        i = random.randint(min_power2, max_power2 - 1)
        lower = 2**i + 1
        upper = 2 ** (i + 1) - 1
        return random.randint(lower, upper)


def is_aligned(dim: int, align_size: int) -> bool:
    return dim % align_size == 0


def get_m_k_n(dtype: Any) -> Tuple[int, int, int]:
    uniform = random.choices([True, False], [0.5, 0.5])[0]
    align_size = get_alignment_size_dtype(dtype)

    # repeat until tensors fit in memory
    while True:
        if uniform:
            m = random.randint(1, 65536)
            k = random.randint(1, 65536)
            n = random.randint(1, 65536)
        else:
            m = get_random_dim()
            k = get_random_dim()
            n = get_random_dim()

        if all(is_aligned(dim, align_size) for dim in [m, k, n]):
            # skip if already aligned
            continue

        if fits_in_memory(dtype, m, k, n):
            return (m, k, n)


def transpose_tensors() -> Tuple[bool, bool]:
    p_transpose_both = 0.05
    transpose_both = random.choices(
        [True, False], [p_transpose_both, 1 - p_transpose_both]
    )[0]
    if transpose_both:
        return (True, True)
    transpose_left = (True, False)
    transpose_right = (False, True)
    no_transpose = (False, False)
    return random.choices([transpose_left, transpose_right, no_transpose])[0]


def prepadded() -> bool:
    return random.choices([True, False], [p_prepadded, 1 - p_prepadded])[0]


def get_dtype() -> Any:
    dtype_choices = [torch.float16, torch.bfloat16, torch.float32]
    return random.choices(dtype_choices)[0]


def set_precision(dtype: Any) -> None:
    if dtype == torch.float32:
        precisions = ["high", "highest"]
        weights = [1 - p_float32_prec_highest, p_float32_prec_highest]
        precision = random.choices(precisions, weights)[0]
    else:
        precision = "high"
    torch.set_float32_matmul_precision(precision)


def main(num_samples: int) -> None:
    for i in tqdm(range(num_samples)):
        dtype = get_dtype()
        set_precision(dtype)
        m, k, n = get_m_k_n(dtype)

        (transpose_left, transpose_right) = transpose_tensors()
        prepadded_left = prepadded()
        prepadded_right = prepadded()

        print("Benchmarking the following input:")
        print(f"m={m} k={k} n={n} dtype={dtype}")
        print(f"transpose_left={transpose_left} transpose_right={transpose_right}")
        print(f"prepadded_left={prepadded_left} prepadded_right={prepadded_right}")

        for i in range(3):
            benchmark(
                m,
                k,
                n,
                transpose_left,
                transpose_right,
                dtype,
                prepadded_left,
                prepadded_right,
            )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--device",
        type=int,
        default=None,
        help="torch.cuda.set_device(device) will be used",
    )
    parser.add_argument(
        "--use-heuristic",
        action="store_true",
        help="Use learned heuristic instead of collecting data.",
    )
    parser.add_argument(
        "-o",
        type=str,
        default="a100_data.txt",
        help="Path to file where AutoHeuristic will log results.",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=1000,
        help="Number of samples to collect.",
    )

    args = parser.parse_args()
    if args.use_heuristic:
        torch._inductor.config.autoheuristic_use = "pad_mm"
    else:
        torch._inductor.config.autoheuristic_collect = "pad_mm"
    torch._inductor.config.autoheuristic_log_path = args.o
    if args.device is not None:
        torch.cuda.set_device(args.device)
    random.seed(time.time())
    main(args.num_samples)