mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add out_dtype kw argument to optimize_bsr_dense_addmm (#136626)
As in the title: adds an `out_dtype` keyword argument to `optimize_bsr_dense_addmm`. Addresses the task in https://github.com/pytorch/ao/pull/821#issuecomment-2373290266

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136626
Approved by: https://github.com/amjames, https://github.com/cpuhrsch
This commit is contained in:
committed by
PyTorch MergeBot
parent
5a13282c75
commit
8c840fb921
@ -749,6 +749,7 @@ def bsr_dense_addmm_meta(
|
||||
num_stages=None,
|
||||
sparsity=None,
|
||||
dtype=None,
|
||||
out_dtype=None,
|
||||
_version=0,
|
||||
**extra,
|
||||
):
|
||||
@ -757,15 +758,31 @@ def bsr_dense_addmm_meta(
|
||||
# bsr_dense_addmm_meta functionality.
|
||||
if dtype is None:
|
||||
dtype = torch.float16
|
||||
if out_dtype is None:
|
||||
out_dtype = dtype
|
||||
if sparsity is None:
|
||||
sparsity = 0.5
|
||||
if {SPLIT_N, num_warps, num_stages, GROUP_SIZE_ROW} == {None}:
|
||||
device_name = torch.cuda.get_device_name()
|
||||
key = (M, K, N, Ms, Ks, beta == 0, beta == 1, alpha == 1)
|
||||
if dtype is out_dtype:
|
||||
version_dtype = dtype
|
||||
else:
|
||||
version_dtype = dtype, out_dtype
|
||||
meta = get_meta(
|
||||
"bsr_dense_addmm", key, device_name, version=(_version, dtype, sparsity)
|
||||
"bsr_dense_addmm",
|
||||
key,
|
||||
device_name,
|
||||
version=(_version, version_dtype, sparsity),
|
||||
)
|
||||
if meta is None and sparsity != 0.5:
|
||||
meta = get_meta(
|
||||
"bsr_dense_addmm",
|
||||
key,
|
||||
device_name,
|
||||
version=(_version, version_dtype, 0.5),
|
||||
)
|
||||
if meta is None and dtype is not out_dtype:
|
||||
meta = get_meta(
|
||||
"bsr_dense_addmm", key, device_name, version=(_version, dtype, 0.5)
|
||||
)
|
||||
@ -775,8 +792,15 @@ def bsr_dense_addmm_meta(
|
||||
"bsr_dense_addmm",
|
||||
(*key[:2], "*", *key[3:]),
|
||||
device_name,
|
||||
version=(_version, dtype, 0.5),
|
||||
version=(_version, version_dtype, 0.5),
|
||||
)
|
||||
if matching_meta is None and dtype is not out_dtype:
|
||||
matching_meta = get_meta(
|
||||
"bsr_dense_addmm",
|
||||
(*key[:2], "*", *key[3:]),
|
||||
device_name,
|
||||
version=(_version, dtype, 0.5),
|
||||
)
|
||||
for mkey in sorted(matching_meta or {}):
|
||||
meta_ = matching_meta[mkey]
|
||||
n = mkey[2]
|
||||
@ -794,7 +818,7 @@ def bsr_dense_addmm_meta(
|
||||
# message
|
||||
warn_once(
|
||||
"bsr_dense_addmm uses non-optimal triton kernel parameters"
|
||||
f" for {M=} {K=} {N=} {Ms=}, {Ks=} {beta=} {alpha=} {dtype=}"
|
||||
f" for {M=} {K=} {N=} {Ms=}, {Ks=} {beta=} {alpha=} {dtype=} {out_dtype=}"
|
||||
)
|
||||
|
||||
SPLIT_N = SPLIT_N or max(N // Ms, 1)
|
||||
@ -1211,7 +1235,8 @@ def bsr_dense_addmm(
|
||||
beta,
|
||||
alpha,
|
||||
sparsity=sparsity,
|
||||
dtype=out.dtype,
|
||||
dtype=dense.dtype,
|
||||
out_dtype=out.dtype,
|
||||
)
|
||||
out_backup = out
|
||||
|
||||
|
@ -643,7 +643,15 @@ def tune_bsr_dense_addmm(
|
||||
# Compute the key of parameters:
|
||||
sparsity = round(1 - bsr._nnz() * BM * BK / (M * K), 2)
|
||||
dtype = bsr.dtype
|
||||
version = (0, dtype, sparsity)
|
||||
if out is None:
|
||||
out_dtype = dtype
|
||||
else:
|
||||
out_dtype = out.dtype
|
||||
if out_dtype is dtype:
|
||||
version_dtype = dtype
|
||||
else:
|
||||
version_dtype = (dtype, out_dtype)
|
||||
version = (0, version_dtype, sparsity)
|
||||
key = (M, K, N, BM, BK, beta == 0, beta == 1, alpha == 1)
|
||||
|
||||
# For tuning, for an initial state, use parameters from the
|
||||
@ -739,6 +747,7 @@ def optimize_bsr_dense_addmm(
|
||||
use_left_alpha=False,
|
||||
use_right_alpha=False,
|
||||
dtype=torch.float16,
|
||||
out_dtype=None,
|
||||
device="cuda",
|
||||
sparsity=0.5,
|
||||
force=False,
|
||||
@ -755,6 +764,10 @@ def optimize_bsr_dense_addmm(
|
||||
right_alpha = (
|
||||
make_tensor(n, dtype=dtype, device=device) if use_right_alpha else None
|
||||
)
|
||||
if out_dtype is not None:
|
||||
out = dense.new_empty((m, n), dtype=out_dtype)
|
||||
else:
|
||||
out = None
|
||||
tune_bsr_dense_addmm(
|
||||
input,
|
||||
bsr,
|
||||
@ -763,6 +776,7 @@ def optimize_bsr_dense_addmm(
|
||||
alpha=alpha,
|
||||
left_alpha=left_alpha,
|
||||
right_alpha=right_alpha,
|
||||
out=out,
|
||||
store=True,
|
||||
force=force,
|
||||
verbose=verbose,
|
||||
|
Reference in New Issue
Block a user