Compare commits


7 Commits

Author SHA1 Message Date
764fea0bc7 fix and disable upcast fp32 2025-07-11 12:10:43 -04:00
cb4a36b6f0 enable native matmul = True 2025-07-08 12:28:31 -04:00
b9c496a0ed lint and fix 2025-07-08 00:01:46 -04:00
cb8aa1d37f fix 2025-07-07 18:48:26 -04:00
15a3bcc968 add heuristics 2025-07-07 17:17:59 -04:00
f758d3d518 add ops.dot codegen 2025-07-07 15:24:35 -04:00
499fc5bd4f add ops.dot 2025-07-07 14:22:52 -04:00
10 changed files with 481 additions and 16 deletions

View File

@@ -0,0 +1,128 @@
# Owner(s): ["module: inductor"]
import os
import sys
import torch
from torch._dynamo.utils import same
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_triton_code
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch._dynamo.testing import rand_strided
from torch._inductor import config as inductor_config
aten = torch.ops.aten
@inductor_config.patch(
{"triton.enable_native_matmul": True, "coordinate_descent_tuning": False}
)
class TestTritonDotReduction(TestCase):
def test_matmul(self):
def f(x, y):
z = x @ y
return z
M, K, N = 128, 128, 128
x = rand_strided((M, K), (K, 1), device=GPU_TYPE)
y = rand_strided((K, N), (N, 1), device=GPU_TYPE)
compiled = torch.compile(f)
actual = compiled(x, y)
expect = f(x, y)
self.assertTrue(same(expect, actual, tol=1e-3), f"{expect=}\n{actual=}\n")
code = run_and_get_triton_code(compiled, x, y)
lines = [line for line in code.split("\n") if "tl.dot" in line]
assert len(lines) == 1
@inductor_config.patch({"triton.codegen_upcast_to_fp32": False})
def test_matmul_fp16(self):
def f(x, y):
z = x @ y
return z
M, K, N = 128, 128, 128
x = rand_strided((M, K), (K, 1), dtype=torch.float16, device=GPU_TYPE)
y = rand_strided((K, N), (N, 1), dtype=torch.float16, device=GPU_TYPE)
compiled = torch.compile(f)
actual = compiled(x, y)
expect = f(x, y)
self.assertTrue(same(expect, actual, tol=1e-3), f"{expect=}\n{actual=}\n")
code = run_and_get_triton_code(compiled, x, y)
lines = [line for line in code.split("\n") if "tl.dot" in line]
assert len(lines) == 1
def test_mm_add(self):
def f(x, y, z, w):
return x @ y + z @ w
M, K, N = 128, 128, 128
x = rand_strided((M, K), (K, 1), device=GPU_TYPE)
y = rand_strided((K, N), (N, 1), device=GPU_TYPE)
w = rand_strided((M, K), (K, 1), device=GPU_TYPE)
z = rand_strided((K, N), (N, 1), device=GPU_TYPE)
compiled = torch.compile(f)
actual = compiled(x, y, z, w)
expect = f(x, y, z, w)
self.assertTrue(same(expect, actual, tol=1e-3), f"{expect=}\n{actual=}\n")
code = run_and_get_triton_code(compiled, x, y, z, w)
lines = [line for line in code.split("\n") if "tl.dot" in line]
assert len(lines) == 2
def test_mm_complex(self):
def f(x, y, z, w):
return x[z] @ y + w + 3
M, K, N = 128, 128, 128
x = rand_strided((M, K), (K, 1), device=GPU_TYPE)
y = rand_strided((K, N), (N, 1), device=GPU_TYPE)
z = torch.randint(M, (M, K), dtype=torch.long, device=GPU_TYPE)
w = rand_strided((M, N), (N, 1), device=GPU_TYPE)
compiled = torch.compile(f)
actual = compiled(x, y, z, w)
expect = f(x, y, z, w)
self.assertTrue(same(expect, actual, tol=1e-3), f"{expect=}\n{actual=}\n")
code = run_and_get_triton_code(compiled, x, y, z, w)
lines = [line for line in code.split("\n") if "tl.dot" in line]
assert len(lines) == 1
def test_batchmatmul(self):
def f(x, y):
z = torch.bmm(x, y)
return z
B, M, K, N = 256, 128, 128, 128
x = rand_strided((B, M, K), (M * K, K, 1), device=GPU_TYPE)
y = rand_strided((B, K, N), (K * N, N, 1), device=GPU_TYPE)
compiled = torch.compile(f)
actual = compiled(x, y)
expect = f(x, y)
self.assertTrue(same(expect, actual, tol=1e-3), f"{expect=}\n{actual=}\n")
code = run_and_get_triton_code(compiled, x, y)
lines = [line for line in code.split("\n") if "tl.dot" in line]
assert len(lines) == 1
if HAS_GPU:
torch.set_default_device(GPU_TYPE)
if __name__ == "__main__":
if HAS_GPU:
run_tests()
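
Reviewer note (illustrative sketch, not part of the patch): a minimal end-to-end usage example mirroring test_matmul above, assuming a CUDA GPU; it uses only APIs that already appear in this test file.

import torch
from torch._inductor import config as inductor_config
from torch._inductor.utils import run_and_get_triton_code

def matmul(x, y):
    return x @ y

with inductor_config.patch({"triton.enable_native_matmul": True}):
    a = torch.randn(128, 128, device="cuda")
    b = torch.randn(128, 128, device="cuda")
    compiled = torch.compile(matmul)
    code = run_and_get_triton_code(compiled, a, b)
    # With native matmul enabled, the generated kernel computes the product
    # with tl.dot instead of dispatching to a templated/extern mm kernel.
    assert "tl.dot" in code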

View File

@@ -1015,6 +1015,11 @@ class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
f"{type(self).__name__}: halide_clamp only implemented for Halide backend"
)
def dot(self, x: OpVarT, y: OpVarT) -> OpVarT:
raise NotImplementedError(
f"{type(self).__name__}: dot only implemented for Triton backend"
)
def inline_asm_elementwise(
self,
*inputs: OpVarT,

View File

@@ -1426,7 +1426,7 @@ class SIMDScheduling(BaseScheduling):
def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures):
node_schedule = kernel_features.node_schedule
tiling, tiling_score = self.get_tiling_and_scores(
node_schedule,
kernel_features.numel,
@@ -2317,6 +2317,19 @@ class SIMDScheduling(BaseScheduling):
# Tiled reductions are gated by a config flag.
default_tiling = cls.create_tiling([numel], [reduction_numel])
# Force tiling compatible with matmul dimensions
# when natively generating matmul without template calls.
if torch._inductor.config.triton.enable_native_matmul:
for node in EnableReduction.filter(node_schedule):
# A[M,K] @ B[K,N]
# force tiling to be {'y':M, 'x':N, 'r0_':K}
if node.node.get_reduction_type() == "dot" :
node_ranges = node.get_ranges()
range_y_x = node_ranges[0] #(M,N)
range_r = node_ranges[1] #(K)
tiling = cls.create_tiling(range_y_x, range_r)
return tiling, None
# # TODO: enable by default
if (
torch._inductor.config.triton.coalesce_tiling_analysis
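
Reviewer note (illustrative sketch, not part of the patch): what the forced tiling amounts to for a plain matmul, using hypothetical sizes; the real code builds the tiling via cls.create_tiling(range_y_x, range_r) as shown above.

# For C[M, N] = A[M, K] @ B[K, N], the "dot" reduction node reports
# pointwise ranges (M, N) and reduction range (K,).
M, N, K = 128, 128, 128
node_ranges = ([M, N], [K])        # node.get_ranges()
range_y_x, range_r = node_ranges
# create_tiling(range_y_x, range_r) then pins the tiling to, conceptually:
forced_tiling = {"y": M, "x": N, "r0_": K}
print(forced_tiling)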

View File

@@ -1037,6 +1037,41 @@ class TritonOverrides(OpOverrides):
def where(a, b, c):
return f"tl.where({a}, {b}, {c})"
@staticmethod
def dot(a, b):
dense_sizes = V.kernel.dense_size_list()
allow_tf32 = torch.backends.cuda.matmul.allow_tf32
assert torch._inductor.config.triton.enable_native_matmul
# mm case
if len(dense_sizes) == 3:
Y = dense_sizes[0]
X = dense_sizes[1]
R = dense_sizes[2]
# a = (1,YBLOCK,RBLOCK)
# b = (XBLOCK,1,RBLOCK)
a_squeezed = triton_reshape(str(a), [1, Y, R], [Y, R]) # (Y,R)
b_squeezed = triton_reshape(str(b), [X, 1, R], [X, R]) # (X,R)
b_transposed = f"tl.trans({b_squeezed})" # (R,X)
return f"tl.dot({a_squeezed}, {b_transposed}, allow_tf32={allow_tf32})" # (Y,X)
elif len(dense_sizes) == 4:
Y = dense_sizes[1]
X = dense_sizes[2]
R = dense_sizes[3]
# a = (ZBLOCK,YBLOCK,1,RBLOCK)
# b = (ZBLOCK,1,XBLOCK,RBLOCK)
# Note that the autotuner config will always ensure ZBLOCK=1
a_squeezed = triton_reshape(str(a), [1, Y, 1, R], [Y, R])
b_squeezed = triton_reshape(str(b), [1, 1, X, R], [X, R])
b_transposed = f"tl.trans({b_squeezed})" # (R,X)
return f"tl.dot({a_squeezed}, {b_transposed}, allow_tf32={allow_tf32})" # (Y,X)
else:
raise NotImplementedError("tl.dot can only do mm and bmm")
@staticmethod
def inline_asm_elementwise(
*inputs, asm, constraints=None, dtype=torch.float32, is_pure=True, pack=1
@@ -2049,7 +2084,45 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
expand_str = None
index_str = self.index_to_str(index)
if isinstance(index, sympy.Integer):
expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str()
if (
self.inside_reduction
and torch._inductor.config.triton.enable_native_matmul
and self.current_node.node.get_reduction_type() == "dot"
):
# Consider the following code:
#
# tmp0 = tl.load(in_ptr0 + y0)
# tmp1 = tl.full([YBLOCK, XBLOCK, R0_BLOCK], 128, tl.int32)
# tmp2 = tmp0 + tmp1
# tmp3 = tmp0 < 0
# tmp4 = tl.where(tmp3, tmp2, tmp0)
# x1 = xindex
# acc = tl.full([YBLOCK, XBLOCK], 0, tl.float32)
#
# for r_offset in range(0, r0_numel, R0_BLOCK):
# r = r_offset + r0_base
# a = tl.load(in_ptr2 + (x1 + 128 * r))
# b = tl.load(in_ptr1 + (r + 128 * tmp4))
# dot = tl.dot(
# tl.reshape(b, [YBLOCK, R0_BLOCK]),
# tl.trans(tl.reshape(a, [XBLOCK, R0_BLOCK]))
# )
#
# This handles an indirect matmul: A[y, :] @ B
# To deal with negative indices in the indirection,
# the generated code adds a constant (128) to the index.
#
# However, creating a dense constant of shape [YBLOCK, XBLOCK, R0_BLOCK]
# breaks axis alignment for tl.dot, which expects inputs shaped (Y, R) x (R, X).
#
# Instead of broadcasting a dense constant, we use a size-1 scalar constant
# to preserve the correct dependency on the [Y, R] axes for tl.dot.
expand_str = str([1] * len(self.dense_size_list()))
else:
expand_str = (
f"{copy_shape}.shape" if copy_shape else self.dense_size_str()
)
index_str = f"tl.full({expand_str}, {index_str}, tl.int32)"
if self.fixed_config and not self._has_constant_xmask():
mask_vars = OrderedSet(["xmask"])
@@ -2487,13 +2560,31 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
masks.append(self._load_mask)
reduction_range_prefix = self.range_trees[-1].prefix[0]
# When we do native matmul codegen,
# we don't want to keep the R0_BLOCK/R1_BLOCK dimension in the accumulator,
# so instead of naively calling dense_size_str(), we filter the
# reduction block out of the accumulator shape.
# In the bmm (Z,Y,R) x (Z,R,X) case, we also remove the Z dimension from the
# accumulator, because 3D (Z,Y,X) tl.dot turns out to be slower than 2D tl.dot;
# instead, we force ZBLOCK to always be 1 during autotuning.
dense_size_str: str
if (
torch._inductor.config.triton.enable_native_matmul
and reduction_type == "dot"
):
dense_sizes = self.dense_size_list()
assert len(dense_sizes) >= 3
xy_sizes_only = [size for size in dense_sizes if "X" in size or "Y" in size]
dense_size_str = f"[{', '.join(xy_sizes_only)}]"
else:
dense_size_str = self.dense_size_str()
# Say we have
# tmp0 = ops.constant(1, torch.int64)
# tmp1 = ops.reduction(torch.int64, torch.int64, "sum", tmp0)
# tmp0 in the triton code is either a scalar, or single-element tensor
# so if we emit tl.sum directly, it will only give 1 instead of RBLOCK * 1
# To avoid this, we broadcast to the expected shape first.
dense_size_str = self.dense_size_str()
value = self._map_tuple_or_scalar(
lambda v: self.cse.generate(
self.compute,
@@ -2522,6 +2613,12 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
value = self.reduction_resize(
f"{module}.{reduction_type}2({value}, {dim})"
)
elif reduction_type == "dot":
is_bmm = len(self.dense_size_list()) == 4
if is_bmm:
value = f"{value}[None,:,:,None]" # (Y,X) to (Z=1,Y,X,R=1)
else:
value = f"{value}[:,:,None]" # (Y,X) to (Y,X,R=1)
else:
value = self.reduction_resize(
f"{module}.{reduction_type}({value}, {dim})"
@@ -2587,6 +2684,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
pass
elif isinstance(value, tuple):
masked_value = [_mask_value(v, d) for v, d in zip(value, default)]
elif reduction_type == "dot":
# We don't need the where condition in native matmul.
masked_value = self.cse.generate(self.compute, value, dtype=value.dtype)
else:
masked_value = _mask_value(value, default)
@@ -2640,9 +2740,20 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
default = ir.Reduction.default_accumulator(reduction_type, src_dtype)
default = self._map_tuple_or_scalar(constant_repr, default)
if not isinstance(default, tuple):
self.body.writeline(
f"{accumulator} = tl.full({self.dense_size_str()}, {default}, {acc_type})"
)
if reduction_type == "dot":
dense_sizes = self.dense_size_list()
assert len(dense_sizes) >= 3
xy_sizes_only = [
size for size in dense_sizes if "X" in size or "Y" in size
]
dense_size_str = f"[{', '.join(xy_sizes_only)}]"
self.body.writeline(
f"{accumulator} = tl.full({dense_size_str}, {default}, {acc_type})"
)
else:
self.body.writeline(
f"{accumulator} = tl.full({self.dense_size_str()}, {default}, {acc_type})"
)
if reduction_type in ("argmax", "argmin"):
accumulator_index = f"_{result_var}_index"
@@ -2717,9 +2828,12 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
else:
combine_fn = ir.get_reduction_combine_fn(reduction_type, src_dtype)
updated = combine_fn(accumulator, value)
self.compute.writeline(
f"{accumulator} = {where_cond(updated, accumulator)}"
)
if reduction_type == "dot":
self.compute.writeline(f"{accumulator} = {updated}")
else:
self.compute.writeline(
f"{accumulator} = {where_cond(updated, accumulator)}"
)
if src_dtype == torch.bool:
# This is only really used for aten.any. It changes the
@@ -3658,6 +3772,10 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
"signature": triton_meta_signature,
"device": DeviceProperties.create(V.graph.get_current_device_or_throw()),
"constants": {},
"native_matmul": (
torch._inductor.config.triton.enable_native_matmul
and ("tl.dot" in str(self.body) or "tl.dot" in str(self.compute))
),
}
# Skip memory optimization for forward of the training loop where we expect
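
Reviewer note (illustrative sketch, not part of the patch): the accumulator-shape filtering that appears twice in this file (once for the reduction broadcast shape, once for the accumulator initialization), lifted into a standalone function to show what it yields for the mm and bmm dense-size lists.

def matmul_accumulator_shape(dense_sizes):
    # Keep only the Y/X block dims; ZBLOCK and the R0_/R1_ reduction blocks
    # are dropped, mirroring the xy_sizes_only filtering above.
    xy_sizes_only = [size for size in dense_sizes if "X" in size or "Y" in size]
    return f"[{', '.join(xy_sizes_only)}]"

print(matmul_accumulator_shape(["YBLOCK", "XBLOCK", "R0_BLOCK"]))            # mm:  [YBLOCK, XBLOCK]
print(matmul_accumulator_shape(["ZBLOCK", "YBLOCK", "XBLOCK", "R0_BLOCK"]))  # bmm: [YBLOCK, XBLOCK]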

View File

@@ -1197,6 +1197,9 @@ class triton:
# For best results, this should be used with prefer_nd_tiling.
tile_reductions: bool = False
# Codegen matmul natively with tl.dot instead of calling a template.
enable_native_matmul: bool = True
# should we stop a fusion to allow better tiling?
tiling_prevents_pointwise_fusion = True
tiling_prevents_reduction_fusion = True
@@ -1279,7 +1282,7 @@ class triton:
inject_relu_bug_TESTING_ONLY: Optional[str] = None
# Whether to upcast float16 / bfloat16 to float32 in triton codegen (Experimental)
codegen_upcast_to_fp32 = True
codegen_upcast_to_fp32 = False
# Whether persistent matmul kernels should be enabled. This flag only has an effect on H100
# with a version of triton new enough to support TMA

View File

@@ -347,6 +347,10 @@ class DtypePropagationOpsHandler:
# TODO - way of registering dtype for op in backend
return torch.int32
@staticmethod
def dot(x: DTypeArg, y: DTypeArg) -> torch.dtype:
return promote_types([x, y])
@staticmethod
def inline_asm_elementwise(
*inputs, asm, constraints=None, dtype=torch.float32, is_pure=True, pack=1

View File

@@ -63,7 +63,11 @@ from .micro_pipeline_tp import micro_pipeline_tp_pass
from .pre_grad import is_same_dict, save_inductor_dict
from .reinplace import reinplace_inplaceable_ops
from .split_cat import POST_GRAD_PATTERNS
from ..lowering import (
make_pointwise,
make_reduction,
transform_args
)
_T = TypeVar("_T")
_P = ParamSpec("_P")
@@ -1423,6 +1427,53 @@ def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp):
match.replace_by_example(repl, [inp, mat1, mat2])
def native_matmul_extra_check(match):
"""
Currently we only enable native matmul for Triton on NVIDIA GPUs.
"""
return (
match.kwargs["mat1"].meta["val"].device.type == "cuda"
and config.cuda_backend == "triton"
)
@register_lowering_pattern(
CallFunction(aten.mm, KeywordArg("mat1"), KeywordArg("mat2")),
extra_check=native_matmul_extra_check,
)
def lower_mm_native(match: Match, mat1, mat2):
mat1 = L[aten.unsqueeze](mat1, -1)
mat2 = L[aten.unsqueeze](mat2, 0)
args, kwargs = transform_args(
args=[mat1, mat2],
kwargs={},
broadcast=True,
type_promotion_kind=None,
convert_input_to_bool=False
) # Handles broadcasting the arguments
mul_pointwise = make_pointwise(ops.dot)(*args)
dot_reduction = make_reduction("dot")(mul_pointwise, 1,)
return dot_reduction
@register_lowering_pattern(
CallFunction(aten.bmm, KeywordArg("mat1"), KeywordArg("mat2")),
extra_check=native_matmul_extra_check,
)
def lower_bmm_native(match: Match, mat1, mat2):
mat1 = L[aten.unsqueeze](mat1, -1)
mat2 = L[aten.unsqueeze](mat2, 1)
args, kwargs = transform_args(
args=[mat1, mat2],
kwargs={},
broadcast=True,
type_promotion_kind=None,
convert_input_to_bool=False
) # Handles broadcasting the arguments
mul_pointwise = make_pointwise(ops.dot)(*args)
dot_reduction = make_reduction("dot")(mul_pointwise, 2,)
return dot_reduction
def is_valid_addmm_fusion(match):
mat1, mat2 = match.args
inp = match.kwargs["inp"]
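
Reviewer note (illustrative sketch, not part of the patch): the unsqueeze / broadcast-multiply / reduce decomposition performed by lower_mm_native and lower_bmm_native, checked here in eager PyTorch; in the compiled path the pointwise ops.dot and the "dot" reduction are fused into a single tl.dot.

import torch

M, K, N, batch = 64, 32, 16, 4
A = torch.randn(M, K)
B = torch.randn(K, N)

# mm: A.unsqueeze(-1) -> (M, K, 1), B.unsqueeze(0) -> (1, K, N),
# broadcast multiply -> (M, K, N), reduce over dim 1 (K) -> (M, N)
mm_decomposed = (A.unsqueeze(-1) * B.unsqueeze(0)).sum(dim=1)
assert torch.allclose(mm_decomposed, A @ B, atol=1e-5)

X = torch.randn(batch, M, K)
Y = torch.randn(batch, K, N)
# bmm: X.unsqueeze(-1) -> (batch, M, K, 1), Y.unsqueeze(1) -> (batch, 1, K, N),
# broadcast multiply -> (batch, M, K, N), reduce over dim 2 (K) -> (batch, M, N)
bmm_decomposed = (X.unsqueeze(-1) * Y.unsqueeze(1)).sum(dim=2)
assert torch.allclose(bmm_decomposed, torch.bmm(X, Y), atol=1e-5)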

View File

@@ -1073,6 +1073,7 @@ REDUCTION_COMBINE_FN: dict[str, Callable[..., OpsValue]] = {
"min": ops_wrapper("minimum"),
"prod": ops_wrapper("mul"),
"sum": ops_wrapper("add"),
"dot": ops_wrapper("add"),
"xor_sum": ops_wrapper("bitwise_xor"),
}
@@ -1601,6 +1602,7 @@ class Reduction(Loops):
return {
"sum": zero,
"prod": one,
"dot": zero,
"xor_sum": zero,
"any": zero,
"welford_reduce": (zero, zero, zero),

View File

@@ -686,6 +686,10 @@ class OpsHandler(Generic[T]):
def halide_clamp(self, value: T, size: sympy.Expr, check: bool) -> T:
raise NotImplementedError
# triton-only
def dot(self, x: T, y: T) -> T:
raise NotImplementedError
# triton-only
def inline_asm_elementwise(
self,

View File

@@ -2309,9 +2309,106 @@ def pointwise(
filename=filename,
)
def make_matmul_triton_config(sizes: dict, num_warps: int, num_stages: int):
config = {
"XBLOCK": sizes.get("x"),
"YBLOCK": sizes.get("y"),
"ZBLOCK": sizes.get("z"),
"R0_BLOCK": sizes.get("r"),
}
# Remove keys with None values (i.e., missing in sizes)
config = {k: v for k, v in config.items() if v is not None}
return Config(config, num_warps=num_warps, num_stages=num_stages)
# Each entry is: (sizes_dict, num_warps, num_stages)
triton_native_mm_configs = [
({"x": 32, "y": 32, "r": 16}, 2, 1),
({"x": 32, "y": 32, "r": 128}, 4, 2),
({"x": 32, "y": 64, "r": 32}, 8, 5),
({"x": 64, "y": 32, "r": 32}, 8, 5),
({"x": 64, "y": 32, "r": 128}, 4, 5),
({"x": 64, "y": 64, "r": 16}, 4, 2),
({"x": 64, "y": 64, "r": 32}, 4, 2),
({"x": 64, "y": 64, "r": 64}, 8, 3),
({"x": 64, "y": 64, "r": 128}, 4, 5),
({"x": 64, "y": 128, "r": 32}, 4, 3),
({"x": 64, "y": 128, "r": 32}, 8, 4),
({"x": 64, "y": 128, "r": 64}, 4, 3),
({"x": 64, "y": 128, "r": 128}, 4, 4),
({"x": 128, "y": 64, "r": 32}, 4, 3),
({"x": 128, "y": 64, "r": 32}, 8, 4),
({"x": 128, "y": 128, "r": 32}, 8, 2),
({"x": 128, "y": 128, "r": 32}, 4, 3),
({"x": 128, "y": 128, "r": 64}, 4, 3),
({"x": 128, "y": 128, "r": 64}, 8, 5),
]
triton_native_persistent_mm_configs = [
({"x": 32, "y": 32}, 2, 1),
({"x": 32, "y": 32}, 4, 2),
({"x": 32, "y": 64}, 8, 5),
({"x": 64, "y": 32}, 8, 5),
({"x": 64, "y": 32}, 4, 5),
({"x": 64, "y": 64}, 4, 2),
({"x": 64, "y": 64}, 8, 3),
({"x": 64, "y": 64}, 4, 5),
({"x": 64, "y": 128}, 4, 3),
({"x": 64, "y": 128}, 8, 4),
({"x": 64, "y": 128}, 4, 4),
({"x": 128, "y": 64}, 4, 3),
({"x": 128, "y": 64}, 8, 4),
({"x": 128, "y": 128}, 8, 2),
({"x": 128, "y": 128}, 4, 3),
({"x": 128, "y": 128}, 8, 5),
]
triton_native_bmm_configs = [
({"z": 1, "x": 32, "y": 32, "r": 16}, 2, 1),
({"z": 1, "x": 32, "y": 32, "r": 128}, 4, 2),
({"z": 1, "x": 32, "y": 64, "r": 32}, 8, 5),
({"z": 1, "x": 64, "y": 32, "r": 32}, 8, 5),
({"z": 1, "x": 64, "y": 32, "r": 128}, 4, 5),
({"z": 1, "x": 64, "y": 64, "r": 16}, 4, 2),
({"z": 1, "x": 64, "y": 64, "r": 32}, 4, 2),
({"z": 1, "x": 64, "y": 64, "r": 64}, 8, 3),
({"z": 1, "x": 64, "y": 64, "r": 128}, 4, 5),
({"z": 1, "x": 64, "y": 128, "r": 32}, 4, 3),
({"z": 1, "x": 64, "y": 128, "r": 32}, 8, 4),
({"z": 1, "x": 64, "y": 128, "r": 64}, 4, 3),
({"z": 1, "x": 64, "y": 128, "r": 128}, 4, 4),
({"z": 1, "x": 128, "y": 64, "r": 32}, 4, 3),
({"z": 1, "x": 128, "y": 64, "r": 32}, 8, 4),
({"z": 1, "x": 128, "y": 128, "r": 32}, 8, 2),
({"z": 1, "x": 128, "y": 128, "r": 32}, 4, 3),
({"z": 1, "x": 128, "y": 128, "r": 64}, 4, 3),
({"z": 1, "x": 128, "y": 128, "r": 64}, 8, 5),
]
triton_native_persistent_bmm_configs = [
({"z": 1, "x": 32, "y": 32}, 2, 1),
({"z": 1, "x": 32, "y": 32}, 4, 2),
({"z": 1, "x": 32, "y": 64}, 8, 5),
({"z": 1, "x": 64, "y": 32}, 8, 5),
({"z": 1, "x": 64, "y": 32}, 4, 5),
({"z": 1, "x": 64, "y": 64}, 4, 2),
({"z": 1, "x": 64, "y": 64}, 8, 3),
({"z": 1, "x": 64, "y": 64}, 4, 5),
({"z": 1, "x": 64, "y": 128}, 4, 3),
({"z": 1, "x": 64, "y": 128}, 8, 4),
({"z": 1, "x": 64, "y": 128}, 4, 4),
({"z": 1, "x": 128, "y": 64}, 4, 3),
({"z": 1, "x": 128, "y": 64}, 8, 4),
({"z": 1, "x": 128, "y": 128}, 8, 2),
({"z": 1, "x": 128, "y": 128}, 4, 3),
({"z": 1, "x": 128, "y": 128}, 8, 5),
]
def _reduction_configs(
*, size_hints: dict[str, int], inductor_meta: dict[str, Any]
*,
size_hints: dict[str, int],
inductor_meta: dict[str, Any],
triton_meta: dict[str, Any],
) -> list[Config]:
reduction_hint = inductor_meta.get("reduction_hint", None)
@@ -2340,6 +2437,20 @@ def _reduction_configs(
MAX_R0_BLOCK = 1024
register_intensive = True
if triton_meta["native_matmul"]:
if len(size_hints) == 3:
return [
make_matmul_triton_config(sizes, num_warps, num_stages)
for sizes, num_warps, num_stages in triton_native_mm_configs
]
elif len(size_hints) == 4:
return [
make_matmul_triton_config(sizes, num_warps, num_stages)
for sizes, num_warps, num_stages in triton_native_bmm_configs
]
else:
raise NotImplementedError("native matmul only supports mm/bmm pattern")
def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
# For 3D case with tiling scores, create an adapted version
if "y" in size_hints:
@@ -2494,7 +2605,10 @@ def reduction(
assert triton_meta is not None
configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
configs = _reduction_configs(
size_hints=size_hints, inductor_meta=inductor_meta, triton_meta=triton_meta
)
return cached_autotune(
size_hints,
configs=configs,
@@ -2530,12 +2644,16 @@ def cooperative_reduction(
assert split <= TRITON_MAX_RSPLIT
if inductor_meta["persistent_reduction"]:
configs = _persistent_reduction_configs(
{"x": xnumel, "r0_": rnumel // split}, reduction_hint, inductor_meta
{"x": xnumel, "r0_": rnumel // split},
reduction_hint,
inductor_meta,
triton_meta
)
else:
configs = _reduction_configs(
size_hints={"x": xnumel, "r0_": rnumel // split},
inductor_meta=inductor_meta,
triton_meta=triton_meta,
)
for config in configs:
config.kwargs["RSPLIT"] = split
@@ -2555,12 +2673,27 @@ def _persistent_reduction_configs(
size_hints,
reduction_hint=False,
inductor_meta=None,
triton_meta=None,
):
xnumel = size_hints["x"]
rnumel = get_total_reduction_numel(size_hints)
MAX_PERSISTENT_BLOCK_NUMEL = 4096
if triton_meta["native_matmul"]:
if len(size_hints) == 3:
return [
make_matmul_triton_config(sizes, num_warps, num_stages)
for sizes, num_warps, num_stages in triton_native_persistent_mm_configs
]
elif len(size_hints) == 4:
return [
make_matmul_triton_config(sizes, num_warps, num_stages)
for sizes, num_warps, num_stages in triton_native_persistent_bmm_configs
]
else:
raise NotImplementedError("native matmul only supports mm/bmm pattern")
if "y" not in size_hints:
configs = [
triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
@@ -2625,7 +2758,9 @@ def persistent_reduction(
if inductor_meta.get("no_x_dim"):
size_hints["x"] = 1
configs = _persistent_reduction_configs(size_hints, reduction_hint, inductor_meta)
configs = _persistent_reduction_configs(
size_hints, reduction_hint, inductor_meta, triton_meta
)
return cached_autotune(
size_hints,
@@ -2654,7 +2789,9 @@ def split_scan(
if len(size_hints) != 2:
raise NotImplementedError(f"size_hints: {size_hints}")
configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
configs = _reduction_configs(
size_hints=size_hints, inductor_meta=inductor_meta, triton_meta=triton_meta
)
# Fixup configs to enforce the minimum Rn_BLOCK size
min_rblock = inductor_meta.get("min_split_scan_rblock", 256)
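
Reviewer note (illustrative sketch, not part of the patch): what make_matmul_triton_config builds for one entry of each config table; a plain dict stands in for triton.Config so the snippet runs without Triton installed.

def expand_matmul_sizes(sizes):
    # Mirrors the block-size mapping in make_matmul_triton_config above;
    # keys missing from `sizes` (e.g. "z" for plain mm) are dropped.
    block_sizes = {
        "XBLOCK": sizes.get("x"),
        "YBLOCK": sizes.get("y"),
        "ZBLOCK": sizes.get("z"),
        "R0_BLOCK": sizes.get("r"),
    }
    return {k: v for k, v in block_sizes.items() if v is not None}

print(expand_matmul_sizes({"x": 64, "y": 64, "r": 32}))
# {'XBLOCK': 64, 'YBLOCK': 64, 'R0_BLOCK': 32}  -> wrapped in Config(..., num_warps, num_stages)
print(expand_matmul_sizes({"z": 1, "x": 128, "y": 128, "r": 64}))
# {'XBLOCK': 128, 'YBLOCK': 128, 'ZBLOCK': 1, 'R0_BLOCK': 64}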