Add tune_bsr_dense_addmm as an API to find optimal triton kernel parameters for bsr_dense_addmm (#115499)

As in the title.
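For reference, a minimal usage sketch of the new API (the module path torch.sparse._triton_ops_meta and the tensor shapes/blocksize below are assumptions; the argument names follow the signature added in this patch):

import torch
from torch.sparse._triton_ops_meta import tune_bsr_dense_addmm

# Operands for out = beta * input + alpha * (bsr @ dense)
bsr = torch.randn(1024, 1024, dtype=torch.float16, device="cuda").to_sparse_bsr((64, 64))
dense = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
inp = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")

# Returns the tuned kernel meta parameters; store=True would also record them
# in the kernel parameters database:
meta = tune_bsr_dense_addmm(inp, bsr, dense, beta=1, alpha=1, verbose=True)
print(meta)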

In addition:
- improve the algorithm for finding a minimum of operation timings: break the inner loop early when the next minimum candidate is found (see the sketch after this list)
- add tests and fix bugs
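The early-break behavior in minimize follows this pattern (a simplified sketch of the idea, not the actual implementation):

def greedy_minimize(target, params, neighbors):
    # neighbors(params) yields candidate parameter dicts, closest steps first
    best = target(params)
    while True:
        for candidate in neighbors(params):
            t = target(candidate)
            if t < best:
                # a new minimum candidate: stop scanning the remaining
                # directions and restart the neighborhood scan from it
                best, params = t, candidate
                break
        else:
            # no neighbor improved on the current minimum: done
            return params, best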

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115499
Approved by: https://github.com/cpuhrsch
Author: Pearu Peterson
Date: 2023-12-10 11:02:38 +00:00
Committed by: PyTorch MergeBot
parent 40dc0580a6
commit 32286512cc
3 changed files with 198 additions and 85 deletions


@@ -1,11 +1,13 @@
__all__ = ["get_meta"]
__all__ = ["get_meta", "tune_bsr_dense_addmm"]
import inspect
import itertools
import re
import warnings
from typing import Any, Dict
import torch
from torch.hub import tqdm
from torch.testing import make_tensor
@@ -120,6 +122,7 @@ def minimize(
reference_parameters,
step_func,
max_step=2,
verbose=False,
all_values=None,
):
"""Find a dict of parameters that minimizes the target function using
@@ -160,10 +163,20 @@ def minimize(
if all_values is None:
all_values = dict()
directions = list(range(-max_step, max_step + 1))
names = sorted(initial_parameters)
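# Enumerate all step combinations over the parameter names whose total (L1)
# distance is within max_step; sorting puts the smallest steps first so that
# the nearest neighbors of the current parameters are tried first: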
all_directions = []
for d_tuple in itertools.product(*((directions,) * len(names))):
dist = sum(map(abs, d_tuple))
if dist > 0 and dist <= max_step:
all_directions.append((dist, d_tuple))
all_directions.sort()
try:
reference_target = target_func(reference_parameters)
except Exception as msg:
print(f"{reference_parameters=} lead to failure: {msg}.")
if verbose and "out of resource" not in str(msg):
print(f"{reference_parameters=} lead to failure: {msg}.")
reference_target = None
if reference_target is not None:
all_values[to_key(reference_parameters)] = reference_target
@@ -173,36 +186,44 @@ def minimize(
initial_target = target_func(parameters)
except Exception as msg:
if reference_target is None:
print(f"{initial_parameters=} lead to failure: {msg}. Optimization failed!")
return {}, -1, None
print(
f"{initial_parameters=} lead to failure: {msg}. Using reference parameters instead of initial parameters."
)
if verbose:
print(
f"{initial_parameters=} lead to failure: {msg}. Optimization failed!"
)
return {}, -1, -1, f"{msg}"
if verbose and "out of resource" not in str(msg):
print(
f"{initial_parameters=} lead to failure: {msg}. Using reference parameters instead of initial parameters."
)
parameters = reference_parameters
initial_target = reference_target
if reference_target is None:
print("Using initial parameters instead of reference parameters.")
if verbose:
print("Using initial parameters instead of reference parameters.")
reference_target = initial_target
initial_key = to_key(parameters)
all_values[initial_key] = initial_target
minimal_target = all_values[initial_key] = initial_target
pbar = tqdm(
total=len(all_directions),
desc="Tuning...",
disable=not verbose,
ncols=75,
)
while True:
current_key = to_key(parameters)
minimizer_target = all_values[current_key]
minimizer_key = current_key
new_minimizer = False
for name in parameters:
value = parameters[name]
for direction in range(-max_step, max_step + 1):
for i, (_, d_tuple) in enumerate(all_directions):
pbar.update(1)
next_parameters = parameters.copy()
for name, direction in zip(names, d_tuple):
value = next_parameters[name]
if direction == 0:
continue
next_value = step_func(name, value, direction, parameters)
if next_value == value:
continue
next_parameters = parameters.copy()
break
next_parameters[name] = next_value
else:
next_key = to_key(next_parameters)
if next_key in all_values:
continue
@@ -210,38 +231,41 @@ def minimize(
next_target = target_func(next_parameters)
except Exception as msg:
all_values[next_key] = str(msg)
print(f"{next_parameters=} lead to failure: {msg}. Skipping.")
if verbose and "out of resource" not in str(msg):
print(f"{next_parameters=} lead to failure: {msg}. Skipping.")
continue
all_values[next_key] = next_target
if next_target < minimizer_target:
minimizer_target = next_target
minimizer_key = next_key
new_minimizer = True
if new_minimizer:
parameters = from_key(minimizer_key, parameters)
if next_target < minimal_target:
minimal_target = next_target
parameters = next_parameters
pbar.total += i + 1
break
else:
# ensure stable minimizer:
minimizer_keys = {
k
for k, v in all_values.items()
if isinstance(v, float) and abs(1 - v / minimizer_target) < 0.001
if isinstance(v, float) and abs(1 - v / minimal_target) < 0.001
}
minimizer_key = (
initial_key if initial_key in minimizer_keys else min(minimizer_keys)
)
minimizer_target = all_values[minimizer_key]
parameters = from_key(minimizer_key, parameters)
speedup_incr = (1 - minimizer_target / reference_target) * 100
speedup_incr = (1 - minimal_target / reference_target) * 100
if speedup_incr < 0:
print(
f"{speedup_incr=} is negative. Rerunning minimize with reference parameters as initial parameters."
)
if verbose:
print(
f"{speedup_incr=} is negative. Rerunning minimize with reference parameters as initial parameters."
)
return minimize(
target_func,
reference_parameters,
reference_parameters,
step_func,
max_step=max_step,
verbose=verbose,
all_values=all_values,
)
sensitivity = []
@@ -262,7 +286,7 @@ def minimize(
if next_target is None or isinstance(next_target, str):
rel_diffs.append(0)
continue
rel_diff = (next_target / minimizer_target - 1) * 100
rel_diff = (next_target / minimal_target - 1) * 100
rel_diffs.append(rel_diff)
sensitivity.append((max(rel_diffs), rel_diffs, name))
@@ -278,7 +302,7 @@ def minimize(
f"{name}={parameters[name]} ({left_diffs}...{right_diffs} %)"
)
sensitivity_message = ", ".join(sensitivity_message)
return parameters, speedup_incr, minimizer_target, sensitivity_message
return parameters, speedup_incr, minimal_target, sensitivity_message
def create_blocked_tensor(B, M, N, blocksize, sparsity, dtype, device):
@@ -420,67 +444,80 @@ def optimize_scatter_mm(
)
def optimize_bsr_dense_addmm(
m,
k,
n,
bm,
bk,
def tune_bsr_dense_addmm(
input,
bsr,
dense,
*,
beta=1,
alpha=1,
dtype=torch.float16,
device="cuda",
sparsity=0.5,
out=None,
store=False,
verbose=False,
force=False,
):
"""Tune bsr_dense_addmm kernel parameters against the given inputs.
When store is True, the tuning results will be stored in the
database of kernel parameters.
"""
import triton
from torch.sparse._triton_ops import bsr_dense_addmm
key = (m, k, n, bm, bk, beta == 0, beta == 1, alpha == 1)
version = (0, dtype, sparsity)
N = dense.shape[-1]
values = bsr.values()
crow_indices = bsr.crow_indices()
batch_ndim = crow_indices.dim() - 1
M, K = bsr.shape[batch_ndim : batch_ndim + 2]
BM, BK = values.shape[batch_ndim + 1 : batch_ndim + 3]
# Reference parameters are a set of parameters that leads to a
# successful kernel call and the corresponding timing is used as a
# reference for computing speedups. Avoid changing the reference
# parameters when possible.
reference_meta = dict(
GROUP_SIZE_ROW=1, num_stages=1, num_warps=4, SPLIT_N=max(n // bm, 1)
GROUP_SIZE_ROW=1, num_stages=1, num_warps=4, SPLIT_N=max(N // BM, 1)
)
initial_meta = get_meta("bsr_dense_addmm", key, version=version, exact=True)
# Compute the key of parameters:
sparsity = round(1 - bsr._nnz() * BM * BK / (M * K), 2)
dtype = bsr.dtype
version = (0, dtype, sparsity)
key = (M, K, N, BM, BK, beta == 0, beta == 1, alpha == 1)
# For tuning, use parameters from the database as the initial state, if
# available; otherwise, use the reference parameters.
initial_meta = get_meta("bsr_dense_addmm", key, version=version, exact=True)
if initial_meta is None:
may_skip_update = False
initial_meta = get_meta(
"bsr_dense_addmm", key, version=(0, dtype, 0.5), exact=True
)
if initial_meta is None:
initial_meta = reference_meta
elif not force:
return
return initial_meta
else:
may_skip_update = True
print(f"{key, initial_meta=}")
torch.manual_seed(0)
bsr = create_blocked_tensor(
0, m, k, (bm, bk), sparsity, dtype, device
).to_sparse_bsr((bm, bk))
dense = make_tensor(k, n, dtype=dtype, device=device)
input = make_tensor(m, n, dtype=dtype, device=device)
def bench(meta, bsr=bsr, dense=dense):
# The target function that is minimized in the tuning process:
def bench(meta, input=input, bsr=bsr, dense=dense, alpha=alpha, out=out):
def test_func():
return bsr_dense_addmm(input, bsr, dense, beta=beta, alpha=alpha, meta=meta)
return bsr_dense_addmm(
input, bsr, dense, beta=beta, alpha=alpha, meta=meta, out=out
)
ms_min = triton.testing.do_bench(
test_func, warmup=500, rep=100, fast_flush=False
)
return triton.testing.do_bench(test_func, warmup=500, rep=100, fast_flush=False)
return ms_min
def step_meta_parameter(name, value, direction, meta, m=m, n=n, k=k, bm=bm, bk=bk):
# The step function that increments a specified meta parameter:
def step_meta_parameter(name, value, direction, meta, M=M, N=N, K=K, BM=BM, BK=BK):
# return next value in positive or negative direction, or
# input value if the step would result in an invalid
# value. The input value is assumed to be valid.
is_log = name in {"SPLIT_N", "num_warps"}
min_value = dict(SPLIT_N=1, num_warps=1, num_stages=1, GROUP_SIZE_ROW=1)[name]
max_value = dict(SPLIT_N=max(n // bm, 1)).get(name)
max_value = dict(SPLIT_N=max(N // BM, 1)).get(name)
value_step = dict(SPLIT_N=2, num_warps=2, num_stages=1, GROUP_SIZE_ROW=1)[name]
if is_log:
next_value = (
@@ -494,28 +531,58 @@ def optimize_bsr_dense_addmm(
next_value = max(next_value, min_value)
if max_value is not None:
next_value = min(next_value, max_value)
if name == "SPLIT_N" and n % next_value != 0:
if name == "SPLIT_N" and N % next_value != 0:
return value
return next_value
# Tune:
meta, speedup, timing, sensitivity_message = minimize(
bench, initial_meta, reference_meta, step_meta_parameter, max_step=2
bench,
initial_meta,
reference_meta,
step_meta_parameter,
max_step=2,
verbose=verbose,
)
if verbose:
print(f"-> {sensitivity_message}, {speedup=:.1f} %, {timing=:.3f} ms")
if initial_meta is not reference_meta and initial_meta == meta and not force:
return
print(f"-> {sensitivity_message}, {speedup=:.1f} %, {timing=:.3f} ms")
if store and not (
may_skip_update and meta == initial_meta and initial_meta is not reference_meta
):
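# Record the tuned parameters in the kernel parameters database, keyed by
# the device name, version (dtype and sparsity), and the problem shape key: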
device_name = torch.cuda.get_device_name()
update(
"bsr_dense_addmm",
device_name,
version,
key,
tuple(meta[k] for k in sorted(meta)),
)
if speedup < 0:
return
device_name = torch.cuda.get_device_name()
return meta
update(
"bsr_dense_addmm",
device_name,
version,
key,
tuple(meta[k] for k in sorted(meta)),
def optimize_bsr_dense_addmm(
m,
k,
n,
bm,
bk,
beta=1,
alpha=1,
dtype=torch.float16,
device="cuda",
sparsity=0.5,
force=False,
):
torch.manual_seed(0)
bsr = create_blocked_tensor(
0, m, k, (bm, bk), sparsity, dtype, device
).to_sparse_bsr((bm, bk))
dense = make_tensor(k, n, dtype=dtype, device=device)
input = make_tensor(m, n, dtype=dtype, device=device)
tune_bsr_dense_addmm(
input, bsr, dense, beta=beta, alpha=alpha, store=True, force=force
)
@@ -3851,5 +3918,5 @@ _operation_device_version_data: Dict[Any, Dict] = {
if __name__ == "__main__":
for dtype in [torch.float16, torch.bfloat16, torch.float32]:
for op in ["scatter_mm", "bsr_dense_addmm"]:
for op in ["bsr_dense_addmm"]:
main(op=op, force=False, dtype=dtype)