Add out_dtype kw argument to optimize_bsr_dense_addmm (#136626)

As in the title.

Addresses the task in https://github.com/pytorch/ao/pull/821#issuecomment-2373290266

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136626
Approved by: https://github.com/amjames, https://github.com/cpuhrsch
This commit is contained in:
Pearu Peterson
2024-10-21 13:43:41 +03:00
committed by PyTorch MergeBot
parent 5a13282c75
commit 8c840fb921
3 changed files with 87 additions and 14 deletions

View File

@@ -749,6 +749,7 @@ def bsr_dense_addmm_meta(
num_stages=None,
sparsity=None,
dtype=None,
out_dtype=None,
_version=0,
**extra,
):
@@ -757,15 +758,31 @@ def bsr_dense_addmm_meta(
# bsr_dense_addmm_meta functionality.
if dtype is None:
dtype = torch.float16
if out_dtype is None:
out_dtype = dtype
if sparsity is None:
sparsity = 0.5
if {SPLIT_N, num_warps, num_stages, GROUP_SIZE_ROW} == {None}:
device_name = torch.cuda.get_device_name()
key = (M, K, N, Ms, Ks, beta == 0, beta == 1, alpha == 1)
if dtype is out_dtype:
version_dtype = dtype
else:
version_dtype = dtype, out_dtype
meta = get_meta(
"bsr_dense_addmm", key, device_name, version=(_version, dtype, sparsity)
"bsr_dense_addmm",
key,
device_name,
version=(_version, version_dtype, sparsity),
)
if meta is None and sparsity != 0.5:
meta = get_meta(
"bsr_dense_addmm",
key,
device_name,
version=(_version, version_dtype, 0.5),
)
if meta is None and dtype is not out_dtype:
meta = get_meta(
"bsr_dense_addmm", key, device_name, version=(_version, dtype, 0.5)
)
@@ -775,8 +792,15 @@ def bsr_dense_addmm_meta(
"bsr_dense_addmm",
(*key[:2], "*", *key[3:]),
device_name,
version=(_version, dtype, 0.5),
version=(_version, version_dtype, 0.5),
)
if matching_meta is None and dtype is not out_dtype:
matching_meta = get_meta(
"bsr_dense_addmm",
(*key[:2], "*", *key[3:]),
device_name,
version=(_version, dtype, 0.5),
)
for mkey in sorted(matching_meta or {}):
meta_ = matching_meta[mkey]
n = mkey[2]
@@ -794,7 +818,7 @@ def bsr_dense_addmm_meta(
# message
warn_once(
"bsr_dense_addmm uses non-optimal triton kernel parameters"
f" for {M=} {K=} {N=} {Ms=}, {Ks=} {beta=} {alpha=} {dtype=}"
f" for {M=} {K=} {N=} {Ms=}, {Ks=} {beta=} {alpha=} {dtype=} {out_dtype=}"
)
SPLIT_N = SPLIT_N or max(N // Ms, 1)
@@ -1211,7 +1235,8 @@ def bsr_dense_addmm(
beta,
alpha,
sparsity=sparsity,
dtype=out.dtype,
dtype=dense.dtype,
out_dtype=out.dtype,
)
out_backup = out

View File

@@ -643,7 +643,15 @@ def tune_bsr_dense_addmm(
# Compute the key of parameters:
sparsity = round(1 - bsr._nnz() * BM * BK / (M * K), 2)
dtype = bsr.dtype
version = (0, dtype, sparsity)
if out is None:
out_dtype = dtype
else:
out_dtype = out.dtype
if out_dtype is dtype:
version_dtype = dtype
else:
version_dtype = (dtype, out_dtype)
version = (0, version_dtype, sparsity)
key = (M, K, N, BM, BK, beta == 0, beta == 1, alpha == 1)
# For tuning, for an initial state, use parameters from the
@@ -739,6 +747,7 @@ def optimize_bsr_dense_addmm(
use_left_alpha=False,
use_right_alpha=False,
dtype=torch.float16,
out_dtype=None,
device="cuda",
sparsity=0.5,
force=False,
@@ -755,6 +764,10 @@ def optimize_bsr_dense_addmm(
right_alpha = (
make_tensor(n, dtype=dtype, device=device) if use_right_alpha else None
)
if out_dtype is not None:
out = dense.new_empty((m, n), dtype=out_dtype)
else:
out = None
tune_bsr_dense_addmm(
input,
bsr,
@@ -763,6 +776,7 @@ def optimize_bsr_dense_addmm(
alpha=alpha,
left_alpha=left_alpha,
right_alpha=right_alpha,
out=out,
store=True,
force=force,
verbose=verbose,