Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Add tune_bsr_dense_addmm as an API to find optimal triton kernel parameters for bsr_dense_addmm (#115499)

As in the title. In addition:
- improve the algorithm for finding a minimum of operation timings: break the inner loop early when the next minimum candidate is found
- add tests and fix bugs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115499
Approved by: https://github.com/cpuhrsch
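For orientation, a minimal usage sketch of the new API (hypothetical shapes and CUDA device; assumes the private torch.sparse._triton_ops_meta module this PR extends is importable):

    import torch
    from torch.sparse._triton_ops_meta import tune_bsr_dense_addmm

    # Hypothetical operands; any BSR tensor with matching dense/input works.
    m, k, n, bm, bk = 256, 256, 256, 32, 32
    bsr = torch.randn(m, k, dtype=torch.float16, device="cuda").to_sparse_bsr((bm, bk))
    dense = torch.randn(k, n, dtype=torch.float16, device="cuda")
    input = torch.randn(m, n, dtype=torch.float16, device="cuda")

    # Returns the dict of tuned kernel meta parameters; store=True would also
    # save the result to the database of kernel parameters.
    meta = tune_bsr_dense_addmm(input, bsr, dense, beta=1, alpha=1, verbose=True)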
Committed by: PyTorch MergeBot
Parent: 40dc0580a6
Commit: 32286512cc
torch/sparse/_triton_ops_meta.py
@@ -1,11 +1,13 @@
-__all__ = ["get_meta"]
+__all__ = ["get_meta", "tune_bsr_dense_addmm"]
 
 import inspect
+import itertools
 import re
 import warnings
 from typing import Any, Dict
 
 import torch
+from torch.hub import tqdm
 from torch.testing import make_tensor
 
 
@@ -120,6 +122,7 @@ def minimize(
     reference_parameters,
     step_func,
     max_step=2,
     verbose=False,
+    all_values=None,
 ):
     """Find a dict of parameters that minimizes the target function using
@@ -160,10 +163,20 @@
+    if all_values is None:
+        all_values = dict()
+
+    directions = list(range(-max_step, max_step + 1))
+    names = sorted(initial_parameters)
+    all_directions = []
+    for d_tuple in itertools.product(*((directions,) * len(names))):
+        dist = sum(map(abs, d_tuple))
+        if dist > 0 and dist <= max_step:
+            all_directions.append((dist, d_tuple))
+    all_directions.sort()
+
     try:
         reference_target = target_func(reference_parameters)
     except Exception as msg:
-        print(f"{reference_parameters=} lead to failure: {msg}.")
+        if verbose and "out of resource" not in str(msg):
+            print(f"{reference_parameters=} lead to failure: {msg}.")
         reference_target = None
     if reference_target is not None:
         all_values[to_key(reference_parameters)] = reference_target
@@ -173,36 +186,44 @@ def minimize(
         initial_target = target_func(parameters)
     except Exception as msg:
         if reference_target is None:
-            print(f"{initial_parameters=} lead to failure: {msg}. Optimization failed!")
-            return {}, -1, None
+            if verbose:
+                print(
+                    f"{initial_parameters=} lead to failure: {msg}. Optimization failed!"
+                )
+            return {}, -1, -1, f"{msg}"
-        print(
-            f"{initial_parameters=} lead to failure: {msg}. Using reference parameters instead of initial parameters."
-        )
+        if verbose and "out of resource" not in str(msg):
+            print(
+                f"{initial_parameters=} lead to failure: {msg}. Using reference parameters instead of initial parameters."
+            )
         parameters = reference_parameters
         initial_target = reference_target
 
     if reference_target is None:
-        print("Using initial parameters instead of reference parameters.")
+        if verbose:
+            print("Using initial parameters instead of reference parameters.")
         reference_target = initial_target
 
     initial_key = to_key(parameters)
-    all_values[initial_key] = initial_target
+    minimal_target = all_values[initial_key] = initial_target
+    pbar = tqdm(
+        total=len(all_directions),
+        desc="Tuning...",
+        disable=not verbose,
+        ncols=75,
+    )
     while True:
-        current_key = to_key(parameters)
-        minimizer_target = all_values[current_key]
-        minimizer_key = current_key
-        new_minimizer = False
-        for name in parameters:
-            value = parameters[name]
-            for direction in range(-max_step, max_step + 1):
+        for i, (_, d_tuple) in enumerate(all_directions):
+            pbar.update(1)
+            next_parameters = parameters.copy()
+            for name, direction in zip(names, d_tuple):
+                value = next_parameters[name]
                 if direction == 0:
                     continue
                 next_value = step_func(name, value, direction, parameters)
                 if next_value == value:
-                    continue
-                next_parameters = parameters.copy()
+                    break
                 next_parameters[name] = next_value
+            else:
                 next_key = to_key(next_parameters)
                 if next_key in all_values:
                     continue
@@ -210,38 +231,41 @@ def minimize(
                     next_target = target_func(next_parameters)
                 except Exception as msg:
                     all_values[next_key] = str(msg)
-                    print(f"{next_parameters=} lead to failure: {msg}. Skipping.")
+                    if verbose and "out of resource" not in str(msg):
+                        print(f"{next_parameters=} lead to failure: {msg}. Skipping.")
                     continue
                 all_values[next_key] = next_target
-                if next_target < minimizer_target:
-                    minimizer_target = next_target
-                    minimizer_key = next_key
-                    new_minimizer = True
-        if new_minimizer:
-            parameters = from_key(minimizer_key, parameters)
-        else:
+                if next_target < minimal_target:
+                    minimal_target = next_target
+                    parameters = next_parameters
+                    pbar.total += i + 1
+                    break
+        else:
             # ensure stable minimizer:
             minimizer_keys = {
                 k
                 for k, v in all_values.items()
-                if isinstance(v, float) and abs(1 - v / minimizer_target) < 0.001
+                if isinstance(v, float) and abs(1 - v / minimal_target) < 0.001
             }
             minimizer_key = (
                 initial_key if initial_key in minimizer_keys else min(minimizer_keys)
             )
-            minimizer_target = all_values[minimizer_key]
+            minimal_target = all_values[minimizer_key]
             parameters = from_key(minimizer_key, parameters)
-            speedup_incr = (1 - minimizer_target / reference_target) * 100
+            speedup_incr = (1 - minimal_target / reference_target) * 100
             if speedup_incr < 0:
-                print(
-                    f"{speedup_incr=} is negative. Rerunning minimize with reference parameters as initial parameters."
-                )
+                if verbose:
+                    print(
+                        f"{speedup_incr=} is negative. Rerunning minimize with reference parameters as initial parameters."
+                    )
                 return minimize(
                     target_func,
                     reference_parameters,
                     reference_parameters,
                     step_func,
                     max_step=max_step,
                     verbose=verbose,
+                    all_values=all_values,
                 )
             sensitivity = []
@@ -262,7 +286,7 @@ def minimize(
             if next_target is None or isinstance(next_target, str):
                 rel_diffs.append(0)
                 continue
-            rel_diff = (next_target / minimizer_target - 1) * 100
+            rel_diff = (next_target / minimal_target - 1) * 100
             rel_diffs.append(rel_diff)
         sensitivity.append((max(rel_diffs), rel_diffs, name))
@@ -278,7 +302,7 @@ def minimize(
                 f"{name}={parameters[name]} ({left_diffs}...{right_diffs} %)"
             )
     sensitivity_message = ", ".join(sensitivity_message)
-    return parameters, speedup_incr, minimizer_target, sensitivity_message
+    return parameters, speedup_incr, minimal_target, sensitivity_message
 
 
 def create_blocked_tensor(B, M, N, blocksize, sparsity, dtype, device):
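To make the improved search concrete, here is a self-contained toy sketch of the same idea as the updated minimize() above (enumerate direction tuples sorted by distance, and break the inner loop early to rescan around a new minimum candidate); the names and target function here are illustrative, not the PR's code, and unlike the real minimize() this sketch does not memoize already-tried candidates:

    import itertools

    def toy_minimize(target, params, max_step=2):
        # Enumerate all direction tuples within max_step total distance,
        # nearest neighbors first, as in the updated minimize().
        names = sorted(params)
        steps = range(-max_step, max_step + 1)
        all_directions = sorted(
            (sum(map(abs, d)), d)
            for d in itertools.product(steps, repeat=len(names))
            if 0 < sum(map(abs, d)) <= max_step
        )
        best = target(params)
        while True:
            for _, d in all_directions:
                candidate = {n: params[n] + s for n, s in zip(names, d)}
                value = target(candidate)
                if value < best:
                    # Break the inner loop early: restart the scan
                    # around the new minimum candidate.
                    best, params = value, candidate
                    break
            else:
                return params, best

    # Example: minimize a quadratic over integer-valued parameters.
    print(toy_minimize(lambda p: (p["x"] - 3) ** 2 + (p["y"] + 1) ** 2, {"x": 0, "y": 0}))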
@@ -420,67 +444,80 @@ def optimize_scatter_mm(
     )
 
 
-def optimize_bsr_dense_addmm(
-    m,
-    k,
-    n,
-    bm,
-    bk,
+def tune_bsr_dense_addmm(
+    input,
+    bsr,
+    dense,
+    *,
     beta=1,
     alpha=1,
-    dtype=torch.float16,
-    device="cuda",
-    sparsity=0.5,
+    out=None,
+    store=False,
+    verbose=False,
     force=False,
 ):
+    """Tune bsr_dense_addmm kernel parameters against the given inputs.
+
+    When store is True, the tuning results will be stored in the
+    database of kernel parameters.
+    """
     import triton
 
     from torch.sparse._triton_ops import bsr_dense_addmm
 
-    key = (m, k, n, bm, bk, beta == 0, beta == 1, alpha == 1)
-    version = (0, dtype, sparsity)
+    N = dense.shape[-1]
+    values = bsr.values()
+    crow_indices = bsr.crow_indices()
+    batch_ndim = crow_indices.dim() - 1
+    M, K = bsr.shape[batch_ndim : batch_ndim + 2]
+    BM, BK = values.shape[batch_ndim + 1 : batch_ndim + 3]
 
     # Reference parameters is a set of parameters that leads to a
     # successful kernel call and the corresponding timing is used as a
     # reference for computing speedups. Avoid changing the reference
     # parameters when possible.
     reference_meta = dict(
-        GROUP_SIZE_ROW=1, num_stages=1, num_warps=4, SPLIT_N=max(n // bm, 1)
+        GROUP_SIZE_ROW=1, num_stages=1, num_warps=4, SPLIT_N=max(N // BM, 1)
     )
 
-    initial_meta = get_meta("bsr_dense_addmm", key, version=version, exact=True)
+    # Compute the key of parameters:
+    sparsity = round(1 - bsr._nnz() * BM * BK / (M * K), 2)
+    dtype = bsr.dtype
+    version = (0, dtype, sparsity)
+    key = (M, K, N, BM, BK, beta == 0, beta == 1, alpha == 1)
+
+    # For tuning, for an initial state, use parameters from the
+    # database if available, otherwise, use the reference parameters.
+    initial_meta = get_meta("bsr_dense_addmm", key, version=version, exact=True)
     if initial_meta is None:
+        may_skip_update = False
         initial_meta = get_meta(
             "bsr_dense_addmm", key, version=(0, dtype, 0.5), exact=True
         )
         if initial_meta is None:
             initial_meta = reference_meta
     elif not force:
-        return
+        return initial_meta
+    else:
+        may_skip_update = True
 
-    print(f"{key, initial_meta=}")
-
-    torch.manual_seed(0)
-    bsr = create_blocked_tensor(
-        0, m, k, (bm, bk), sparsity, dtype, device
-    ).to_sparse_bsr((bm, bk))
-    dense = make_tensor(k, n, dtype=dtype, device=device)
-    input = make_tensor(m, n, dtype=dtype, device=device)
-
-    def bench(meta, bsr=bsr, dense=dense):
+    # The target function that is minimized in the tuning process:
+    def bench(meta, input=input, bsr=bsr, dense=dense, alpha=alpha, out=out):
         def test_func():
-            return bsr_dense_addmm(input, bsr, dense, beta=beta, alpha=alpha, meta=meta)
+            return bsr_dense_addmm(
+                input, bsr, dense, beta=beta, alpha=alpha, meta=meta, out=out
+            )
 
-        ms_min = triton.testing.do_bench(
-            test_func, warmup=500, rep=100, fast_flush=False
-        )
-
-        return ms_min
+        return triton.testing.do_bench(test_func, warmup=500, rep=100, fast_flush=False)
 
-    def step_meta_parameter(name, value, direction, meta, m=m, n=n, k=k, bm=bm, bk=bk):
+    # The step function that increments a specified meta parameter:
+    def step_meta_parameter(name, value, direction, meta, M=M, N=N, K=K, BM=BM, BK=BK):
         # return next value in positive or negative direction, or
         # input value if the step will result an invalid
         # value. The input value is assumed to be valid.
         is_log = name in {"SPLIT_N", "num_warps"}
         min_value = dict(SPLIT_N=1, num_warps=1, num_stages=1, GROUP_SIZE_ROW=1)[name]
-        max_value = dict(SPLIT_N=max(n // bm, 1)).get(name)
+        max_value = dict(SPLIT_N=max(N // BM, 1)).get(name)
         value_step = dict(SPLIT_N=2, num_warps=2, num_stages=1, GROUP_SIZE_ROW=1)[name]
         if is_log:
             next_value = (
@@ -494,28 +531,58 @@ def optimize_bsr_dense_addmm(
         next_value = max(next_value, min_value)
         if max_value is not None:
             next_value = min(next_value, max_value)
-        if name == "SPLIT_N" and n % next_value != 0:
+        if name == "SPLIT_N" and N % next_value != 0:
             return value
         return next_value
 
+    # Tune:
     meta, speedup, timing, sensitivity_message = minimize(
-        bench, initial_meta, reference_meta, step_meta_parameter, max_step=2
+        bench,
+        initial_meta,
+        reference_meta,
+        step_meta_parameter,
+        max_step=2,
+        verbose=verbose,
     )
+    if verbose:
+        print(f"-> {sensitivity_message}, {speedup=:.1f} %, {timing=:.3f} ms")
 
-    if initial_meta is not reference_meta and initial_meta == meta and not force:
-        return
-    print(f"-> {sensitivity_message}, {speedup=:.1f} %, {timing=:.3f} ms")
-
-    if speedup < 0:
-        return
-    device_name = torch.cuda.get_device_name()
-
-    update(
-        "bsr_dense_addmm",
-        device_name,
-        version,
-        key,
-        tuple(meta[k] for k in sorted(meta)),
-    )
+    if store and not (
+        may_skip_update and meta == initial_meta and initial_meta is not reference_meta
+    ):
+        device_name = torch.cuda.get_device_name()
+        update(
+            "bsr_dense_addmm",
+            device_name,
+            version,
+            key,
+            tuple(meta[k] for k in sorted(meta)),
+        )
+
+    return meta
+
+
+def optimize_bsr_dense_addmm(
+    m,
+    k,
+    n,
+    bm,
+    bk,
+    beta=1,
+    alpha=1,
+    dtype=torch.float16,
+    device="cuda",
+    sparsity=0.5,
+    force=False,
+):
+    torch.manual_seed(0)
+    bsr = create_blocked_tensor(
+        0, m, k, (bm, bk), sparsity, dtype, device
+    ).to_sparse_bsr((bm, bk))
+    dense = make_tensor(k, n, dtype=dtype, device=device)
+    input = make_tensor(m, n, dtype=dtype, device=device)
+    tune_bsr_dense_addmm(
+        input, bsr, dense, beta=beta, alpha=alpha, store=True, force=force
+    )
@@ -3851,5 +3918,5 @@ _operation_device_version_data: Dict[Any, Dict] = {
 
 if __name__ == "__main__":
     for dtype in [torch.float16, torch.bfloat16, torch.float32]:
-        for op in ["scatter_mm", "bsr_dense_addmm"]:
+        for op in ["bsr_dense_addmm"]:
             main(op=op, force=False, dtype=dtype)
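Finally, a hedged sketch of feeding tuned parameters back into the kernel, reusing the tensors and the meta dict from the sketch near the top of this page (the meta keyword is visible in the bench helper of this diff):

    from torch.sparse._triton_ops import bsr_dense_addmm

    # Pass the meta dict returned by tune_bsr_dense_addmm explicitly; with meta
    # omitted, the kernel should instead look parameters up in the database,
    # which tune_bsr_dense_addmm(..., store=True) populates.
    result = bsr_dense_addmm(input, bsr, dense, beta=1, alpha=1, meta=meta)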
||||