Update auto-tuning support for _scaled_grouped_mm (#150944)

1. Enable strided inputs
2. Implement the "2d/2d", "3d/2d", and "3d/3d" input combinations (a usage sketch of the "3d/3d" case follows this list)
3. Fix the non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
5. Fix cases where the group size along the K dimension is not a multiple of the block size along K
6. Update the meta registration
7. Update synthetic offsets creation
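
For reference, a minimal usage sketch of the newly supported "3d/3d" combination, adapted from the updated tests in this PR. It assumes a CUDA device with FP8 (e4m3) support; the compile options mirror the tests and route the call through the Triton template being autotuned here. The 2d/2d, 2d/3d and 3d/2d combinations additionally take an int32 `offs` tensor of per-group boundaries.

```python
import torch

device = "cuda"
m, n, k, n_groups = 16, 32, 64, 4  # all sizes have to be divisible by 16

# Both operands are 3d (one matrix per group); no offsets are needed.
a = torch.randn(n_groups, m, k, device=device).to(torch.float8_e4m3fn)
b = torch.randn(n_groups, n, k, device=device).to(torch.float8_e4m3fn)
scale_a = torch.rand(n_groups * m, device=device, dtype=torch.float32).view(n_groups, m)
scale_b = torch.rand(n_groups * n, device=device, dtype=torch.float32).view(n_groups, n)

# Autotune with the Triton grouped-MM template only.
f = torch.compile(
    torch._scaled_grouped_mm,
    options={
        "max_autotune": True,
        "max_autotune_gemm_backends": "TRITON",
    },
)
out = f(a, b.transpose(-2, -1), scale_a, scale_b,
        out_dtype=torch.bfloat16, use_fast_accum=True)
print(out.shape)  # (n_groups, m, n)
```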

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150944
Approved by: https://github.com/ngimel
Aleksandar Samardžić
2025-06-06 19:19:49 +00:00
committed by PyTorch MergeBot
parent 1339e88105
commit 09328eb02f
8 changed files with 676 additions and 385 deletions


@@ -1532,7 +1532,7 @@ namespace {
"D, arg ",
arg_idx);
TORCH_CHECK(
scale.is_contiguous(), "scale_a must be contiguous for arg ", arg_idx);
scale.is_contiguous(), "scale must be contiguous for arg ", arg_idx);
TORCH_CHECK(
scale.size(0) == mat.size(dim) * scale_multiplier,
"scale must have the same length as mat for arg ",
@@ -1545,8 +1545,8 @@ namespace {
"D for arg ",
arg_idx);
TORCH_CHECK(
scale.stride(1),
"scale_a must be contiguous in the last dimension for arg ",
scale.stride(1) == 1,
"scale must be contiguous in the last dimension for arg ",
arg_idx);
TORCH_CHECK(
scale.size(0) == mat.size(0),
@@ -1610,6 +1610,7 @@ bool use_fast_accum) {
TORCH_CHECK(!bias.has_value(), "Bias not supported yet");
TORCH_CHECK(!scale_result.has_value(), "Scale result not supported yet");
TORCH_CHECK(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
if (offs.has_value()) {


@@ -1616,7 +1616,7 @@ class TestFP8Matmul(TestCase):
for a, b, ascale, bscale, out in zip(alist, blist, ascalelist, bscalelist, outlist):
out_ref = torch._scaled_mm(a, b.t(), ascale.view(-1, 1), bscale.view(1, -1),
out_dtype=torch.bfloat16, use_fast_accum=use_fast_accum)
self.assertEqual(out, out_ref)
self.assertEqual(out, out_ref, atol=1e-1, rtol=1e-2)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@xfailIfSM100OrLater
@@ -1626,14 +1626,19 @@ class TestFP8Matmul(TestCase):
@parametrize("use_torch_compile", [False, True])
def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, use_torch_compile):
device = "cuda"
m, n, k, n_groups = 16, 16, 16, 4 # all sizes have to be divisible by 16
m, n, k, n_groups = 16, 32, 64, 4 # all sizes have to be divisible by 16
a = torch.randn(m, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
b = torch.randn(n, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
scale_a = torch.arange(m * n_groups, device=device, dtype=torch.float32) / 4
scale_b = torch.arange(n * n_groups, device=device, dtype=torch.float32) / 4
scale_a = torch.rand(m * n_groups, device=device, dtype=torch.float32)
scale_b = torch.rand(n * n_groups, device=device, dtype=torch.float32)
offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32)
f = torch._scaled_grouped_mm
f = torch.compile(f) if use_torch_compile else f
f = torch.compile(
f,
options={
"max_autotune": True,
"max_autotune_gemm_backends": "TRITON",
}) if use_torch_compile else f
out = f(a, b.t(), scale_a, scale_b, offs=offs,
out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
offs_cpu = offs.cpu()
@@ -1657,7 +1662,7 @@ class TestFP8Matmul(TestCase):
def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile):
device = "cuda"
s_int = int(strided)
m, n, k, n_groups = 16, 32, 16, 4
m, n, k, n_groups = 16, 32, 64, 4
a = torch.randn(m * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
self.assertTrue(a.is_contiguous() is not strided)
@@ -1666,11 +1671,16 @@ class TestFP8Matmul(TestCase):
offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
if check_zero_size:
offs[0] = offs[1]
scale_a = torch.arange(n_groups * m, device="cuda", dtype=torch.float32)
scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32)
scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
f = torch._scaled_grouped_mm
f = torch.compile(f, dynamic=False) if use_torch_compile else f
f = torch.compile(
f,
options={
"max_autotune": True,
"max_autotune_gemm_backends": "TRITON",
}) if use_torch_compile else f
out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
@@ -1682,7 +1692,7 @@ class TestFP8Matmul(TestCase):
ascalelist.append(scale_a[start:offs_cpu[i]])
outlist.append(out[start:offs_cpu[i]])
start = offs_cpu[i]
self.scaled_grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum)
self.scaled_grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@@ -1694,16 +1704,21 @@ class TestFP8Matmul(TestCase):
def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided, use_torch_compile):
device = "cuda"
s_int = int(strided)
m, n, k, n_groups = 16, 32, 16, 4
m, n, k, n_groups = 16, 32, 64, 4
a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
self.assertTrue(a.is_contiguous() is not strided)
self.assertTrue(b.is_contiguous() is not strided)
scale_a = torch.ones(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
f = torch._scaled_grouped_mm
f = torch.compile(f) if use_torch_compile else f
f = torch.compile(
f,
options={
"max_autotune": True,
"max_autotune_gemm_backends": "TRITON",
}) if use_torch_compile else f
out = f(a, b.transpose(-2, -1), scale_a, scale_b,
out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
@@ -1719,20 +1734,25 @@ class TestFP8Matmul(TestCase):
def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided, use_torch_compile):
device = "cuda"
s_int = int(strided)
m, n, k, n_groups = 16, 32, 16, 4
m, n, k, n_groups = 16, 32, 64, 4
a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
b = torch.randn(n * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
self.assertTrue(a.is_contiguous() is not strided)
self.assertTrue(b.is_contiguous() is not strided)
scale_a = torch.arange(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
scale_b = torch.arange(n_groups * n, device="cuda", dtype=torch.float32)
scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32)
for check_zero_size in (True, False):
offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32)
if check_zero_size:
offs[0] = offs[1]
f = torch._scaled_grouped_mm
f = torch.compile(f) if use_torch_compile else f
f = torch.compile(
f,
options={
"max_autotune": True,
"max_autotune_gemm_backends": "TRITON",
}) if use_torch_compile else f
out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
offs_cpu = offs.cpu()
@@ -1743,7 +1763,7 @@ class TestFP8Matmul(TestCase):
bscalelist.append(scale_b[start:offs_cpu[i]])
outlist.append(out[:, start:offs_cpu[i]])
start = offs_cpu[i]
self.scaled_grouped_mm_helper(a, blist, scale_a, bscalelist, outlist, fast_accum)
self.scaled_grouped_mm_helper(a, blist, scale_a, bscalelist, outlist, fast_accum)
@unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)


@@ -217,7 +217,6 @@ def mark_nodes_dislike_padding(
aten.convolution,
aten.convolution_backward,
aten._scaled_mm,
aten._scaled_grouped_mm,
]
)
# what's a better way to collect the reduction ops?


@@ -38,7 +38,8 @@ def persistent_mm_grid(M: int, N: int, meta: dict[str, Any], *, cdiv, min):
@SymbolicGridFn
def persistent_grouped_mm_grid(m, n, meta):
def persistent_grouped_mm_grid(*args):
meta = args[-1]
return (meta["NUM_SMS"], 1, 1)


@@ -10,7 +10,6 @@ from torch.utils._triton import has_triton_tma_device
from ..ir import ChoiceCaller, Layout, TensorBox
from ..lowering import register_lowering
from ..runtime.runtime_utils import next_power_of_2
from ..select_algorithm import (
autotune_select_algorithm,
ExternKernelChoice,
@@ -20,7 +19,7 @@ from ..select_algorithm import (
from ..utils import (
get_gpu_shared_memory,
get_num_sms,
get_tma_workspace_arg,
has_free_symbols,
use_aten_gemm_kernels,
)
from .mm_common import (
@@ -52,7 +51,7 @@ _NV_CONFIGS = [
num_stages=num_stages,
num_warps=num_warps,
)
for block_size_m in [64, 128]
for block_size_m in [16, 32, 64, 128]
for block_size_n in [64, 128, 256]
for block_size_k in [64, 128, 256]
for num_stages in [3, 4]
@@ -81,11 +80,11 @@ _AMD_CONFIGS = [
]
def scaled_grouped_mm_configs():
def grouped_mm_configs():
return _AMD_CONFIGS if torch.version.hip else _NV_CONFIGS
def early_config_prune(configs, named_args):
def early_config_prune(g, m, configs, named_args):
dtsize = 1
pruned_configs = []
for config in configs:
@@ -98,14 +97,26 @@ def early_config_prune(configs, named_args):
config.num_warps,
getattr(config, "num_consumer_groups", 0),
)
G, M, N, K = (
named_args["G"],
named_args["M_BUCKET"],
named_args["N"],
named_args["K"],
)
# 1. make sure we have enough smem
# 1. Prune NV configs depending on g and m.
if not torch.version.hip:
if not has_free_symbols((g, m)):
a_is_2d, b_is_2d = named_args["A_IS_2D"], named_args["B_IS_2D"]
m_avg = m // g if a_is_2d and not b_is_2d else m
if m_avg <= 16:
if BLOCK_M > 32:
continue
elif m_avg <= 32:
if BLOCK_M > 64:
continue
elif m_avg <= 64:
if BLOCK_M <= 16:
continue
else:
if BLOCK_M <= 32:
continue
# 2. make sure we have enough smem
max_shared_memory = get_gpu_shared_memory()
if torch.version.hip:
@@ -117,39 +128,7 @@ def early_config_prune(configs, named_args):
use_warp_specialization = num_consumer_groups >= 1
M_PER_GROUP = M // G
MIN_M_TILES = 32 if torch.version.hip else 64
# 2. make sure we don't load M tiles that are too big
if (
not use_warp_specialization
and BLOCK_M > MIN_M_TILES
and BLOCK_M > (M_PER_GROUP * 2)
):
continue
# 3. make sure we don't load N tiles that are too small
if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
continue
num_sm = get_num_sms()
N_TILES = N // BLOCK_N
MIN_N_TILES = 32 if torch.version.hip else 64
# 4. make sure we don't load N tiles that are too big
if (
not use_warp_specialization
and BLOCK_N > MIN_N_TILES
and M * N_TILES < num_sm
):
continue
# 5. make sure we don't load N tiles that are too small
if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
continue
# 6. make sure K can be evenly divided
if K % BLOCK_K != 0:
continue
# 7. make sure we can partition for ws
# 3. make sure we can partition for ws
if use_warp_specialization:
if num_warps != 4:
continue
@@ -166,47 +145,129 @@ def early_config_prune(configs, named_args):
# Copied from fbgemm grouped_gemm.py
triton_scaled_grouped_mm_source = r"""
{{def_kernel("a_ptr", "b_ptr", "a_scale_ptr", "b_scale_ptr", "m_sizes")}}
triton_grouped_mm_source = r"""
{%- if A_IS_2D or B_IS_2D %}
{{def_kernel("a_ptr", "b_ptr", "scale_a_ptr", "scale_b_ptr", "offsets_ptr")}}
{%- else %}
{{def_kernel("a_ptr", "b_ptr", "scale_a_ptr", "scale_b_ptr")}}
{%- endif %}
tidx = tl.program_id(0)
dtype = tl.float8e4nv
TMA_SIZE: tl.constexpr = tl.constexpr(128)
{%- set M_IS_VARYING = A_IS_2D and not B_IS_2D %}
{%- set N_IS_VARYING = not A_IS_2D and B_IS_2D %}
{%- set K_IS_VARYING = A_IS_2D and B_IS_2D %}
workspace_base = ws_ptr + tidx * 2 * TMA_SIZE
c_desc_ptr = None
{%- if A_IS_2D %}
{%- if B_IS_2D %}
G = {{size("offsets_ptr", 0)}}
{%- else %}
G = {{size("b_ptr", 0)}}
{%- endif %}
{%- else %}
{%- if B_IS_2D %}
G = {{size("a_ptr", 0)}}
{%- else %}
G = {{size("a_ptr", 0)}}
{%- endif %}
{%- endif %}
a_desc_ptr = workspace_base
b_desc_ptr = workspace_base + TMA_SIZE
# the b_ptr tensor is given with its last two dims transposed, revert here
triton.language.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=a_desc_ptr,
global_address=a_ptr,
load_size=[BLOCK_M, BLOCK_K],
global_size=[M, K],
element_ty=a_ptr.dtype.element_ty,
M = {{size("a_ptr", -2)}}
N = {{size("b_ptr", -1)}}
K = {{size("a_ptr", -1)}}
A_STRIDE_M = {{stride("a_ptr", -2)}}
A_STRIDE_K = {{stride("a_ptr", -1)}}
{%- if not A_IS_2D %}
A_STRIDE_G = {{stride("a_ptr", 0)}}
SCALE_A_STRIDE_G = {{stride("scale_a_ptr", 0)}}
{%- endif %}
B_STRIDE_N = {{stride("b_ptr", -1)}}
B_STRIDE_K = {{stride("b_ptr", -2)}}
{%- if not B_IS_2D %}
B_STRIDE_G = {{stride("b_ptr", 0)}}
SCALE_B_STRIDE_G = {{stride("scale_b_ptr", 0)}}
{%- endif %}
# fixme: a_desc = tl.make_tensor_descriptor(
a_desc = tl._experimental_make_tensor_descriptor(
a_ptr,
{%- if A_IS_2D %}
shape=[M, K],
# fixme: strides=[A_STRIDE_M, A_STRIDE_K],
strides=[{{stride("a_ptr", -2)}}, {{stride("a_ptr", -1)}}],
block_shape=[BLOCK_M, BLOCK_K],
{%- else %}
shape=[G, M, K],
# fixme: strides=[A_STRIDE_G, A_STRIDE_M, A_STRIDE_K],
strides=[{{stride("a_ptr", 0)}}, {{stride("a_ptr", -2)}}, {{stride("a_ptr", -1)}}],
block_shape=[1, BLOCK_M, BLOCK_K],
{%- endif %}
)
triton.language.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=b_desc_ptr,
global_address=b_ptr,
load_size=[BLOCK_N, BLOCK_K],
global_size=[N * G, K],
element_ty=b_ptr.dtype.element_ty,
)
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
M_end_offset = 0
# fixme: b_desc = tl.make_tensor_descriptor(
b_desc = tl._experimental_make_tensor_descriptor(
b_ptr,
{%- if B_IS_2D %}
shape=[N, K],
# fixme: strides=[B_STRIDE_N, B_STRIDE_K],
strides=[{{stride("b_ptr", -1)}}, {{stride("b_ptr", -2)}}],
block_shape=[BLOCK_N, BLOCK_K],
{%- else %}
shape=[G, N, K],
# fixme: strides=[B_STRIDE_G, B_STRIDE_N, B_STRIDE_K],
strides=[{{stride("b_ptr", 0)}}, {{stride("b_ptr", -1)}}, {{stride("b_ptr", -2)}}],
block_shape=[1, BLOCK_N, BLOCK_K],
{%- endif %}
)
{%- if M_IS_VARYING %}
m_end_offset = 0
{%- endif %}
{%- if N_IS_VARYING %}
n_end_offset = 0
{%- endif %}
{%- if K_IS_VARYING %}
k_end_offset = 0
{%- endif %}
iterated_tiles = 0
for g in tl.range(G):
{%- if M_IS_VARYING %}
# Move across groups
M_start_offset = M_end_offset
M_end_offset = tl.load(m_sizes + g)
m_size = M_end_offset - M_start_offset
m_start_offset = m_end_offset
m_end_offset = tl.load(offsets_ptr + g)
m_size = m_end_offset - m_start_offset
m_scale_start_offset = m_start_offset
{%- else %}
m_start_offset = 0
m_size = M
m_scale_start_offset = g * M
{%- endif %}
{%- if N_IS_VARYING %}
# Move across groups
n_start_offset = n_end_offset
n_end_offset = tl.load(offsets_ptr + g)
n_size = n_end_offset - n_start_offset
n_scale_start_offset = n_start_offset
{%- else %}
n_start_offset = 0
n_size = N
n_scale_start_offset = g * N
{%- endif %}
if m_size > 0 and n_size > 0:
{%- if K_IS_VARYING %}
# Move across groups
k_start_offset = k_end_offset
k_end_offset = tl.load(offsets_ptr + g)
k_size = k_end_offset - k_start_offset
{%- else %}
k_start_offset = 0
k_size = K
{%- endif %}
if m_size > 0:
N_start_offset = g.to(tl.int64) * N
n_size = N
num_m_tiles = tl.cdiv(m_size, BLOCK_M)
num_n_tiles = tl.cdiv(n_size, BLOCK_N)
num_tiles = num_m_tiles * num_n_tiles
@@ -219,64 +280,111 @@ triton_scaled_grouped_mm_source = r"""
tile_n_idx = gidx // num_m_tiles
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
tl.static_assert(K % BLOCK_K == 0)
if USE_TMA_LOAD:
m_offset = (M_start_offset + tile_m_idx * BLOCK_M).to(tl.int32)
n_offset = (N_start_offset + tile_n_idx * BLOCK_N).to(tl.int32)
for k_offset in range(0, K, BLOCK_K):
a = tl._experimental_descriptor_load(
a_desc_ptr,
[m_offset, k_offset],
[BLOCK_M, BLOCK_K],
dtype,
)
b = tl._experimental_descriptor_load(
b_desc_ptr,
[n_offset, k_offset],
[BLOCK_N, BLOCK_K],
dtype,
)
if USE_FAST_ACCUM:
accumulator = tl.dot(a, b.T, accumulator)
else:
accumulator += tl.dot(a, b.T)
else:
offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
a_ptrs = (
a_desc_ptr
+ (M_start_offset + offs_am[:, None]) * K
+ offs_k[None, :]
)
b_ptrs = (
b_desc_ptr
+ (N_start_offset + offs_bn[:, None]) * K
+ offs_k[None, :]
)
for k_offset in range(0, K, BLOCK_K):
a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)
accumulator += tl.dot(a, b.T)
a_ptrs += BLOCK_K
b_ptrs += BLOCK_K
{%- if USE_TMA_LOAD %}
m_offset = (m_start_offset + tile_m_idx * BLOCK_M).to(tl.int32)
n_offset = (n_start_offset + tile_n_idx * BLOCK_N).to(tl.int32)
for k_offset in range(0, k_size, BLOCK_K):
{%- if A_IS_2D %}
a = a_desc.load([m_offset, k_start_offset + k_offset])
{%- else %}
a = a_desc.load([g, m_offset, k_start_offset + k_offset]).reshape(BLOCK_M, BLOCK_K)
{%- endif %}
{%- if B_IS_2D %}
b = b_desc.load([n_offset, k_start_offset + k_offset])
{%- else %}
b = b_desc.load([g, n_offset, k_start_offset + k_offset]).reshape(BLOCK_N, BLOCK_K)
{%- endif %}
{%- if K_IS_VARYING %}
if k_offset + BLOCK_K > k_size:
group_offs_k = k_offset + tl.arange(0, BLOCK_K)
a = tl.where(group_offs_k < k_size, a, 0)
b = tl.where(group_offs_k < k_size, b, 0)
{%- endif %}
{%- if USE_FAST_ACCUM %}
accumulator = tl.dot(a, b.T, accumulator)
{%- else %}
accumulator += tl.dot(a, b.T)
{%- endif %}
{%- else %}
offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = k_start_offset + tl.arange(0, BLOCK_K)
a_ptrs = (
a_ptr
{%- if not A_IS_2D %}
+ g * A_STRIDE_G
{%- endif %}
+ (m_start_offset + offs_am[:, None]) * A_STRIDE_M
+ offs_k[None, :] * A_STRIDE_K
)
b_ptrs = (
b_ptr
{%- if not B_IS_2D %}
+ g * B_STRIDE_G
{%- endif %}
+ (n_start_offset + offs_bn[:, None]) * B_STRIDE_N
+ offs_k[None, :] * B_STRIDE_K
)
for k_offset in range(0, k_size, BLOCK_K):
a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)
if k_offset + BLOCK_K > k_size:
group_offs_k = k_offset + tl.arange(0, BLOCK_K)
a = tl.where(group_offs_k < k_size, a, 0)
b = tl.where(group_offs_k < k_size, b, 0)
{%- if USE_FAST_ACCUM %}
accumulator = tl.dot(a, b.T, accumulator)
{%- else %}
accumulator += tl.dot(a, b.T)
{%- endif %}
a_ptrs += BLOCK_K
b_ptrs += BLOCK_K
{%- endif %}
offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
a_scale = tl.load(
a_scale_ptr + M_start_offset + offs_am[:, None],
scale_a = tl.load(
scale_a_ptr
{%- if A_IS_2D %}
+ m_scale_start_offset
{%- else %}
+ g * SCALE_A_STRIDE_G
{%- endif %}
+ offs_am[:, None],
mask=offs_am[:, None] < m_size,
)
b_scale = tl.load(
b_scale_ptr + N_start_offset + offs_bn[None, :],
scale_b = tl.load(
scale_b_ptr
{%- if B_IS_2D %}
+ n_scale_start_offset
{%- else %}
+ g * SCALE_B_STRIDE_G
{%- endif %}
+ offs_bn[None, :],
mask=offs_bn[None, :] < n_size,
)
c = accumulator.to(tl.float32) * a_scale * b_scale
c = accumulator.to(tl.float32) * scale_a * scale_b
idx_m = (M_start_offset + offs_am[:, None])
{%- if M_IS_VARYING %}
idx_m = (m_start_offset + offs_am[:, None])
{%- else %}
idx_m = offs_am[:, None]
{%- endif %}
{%- if N_IS_VARYING %}
idx_n = (n_start_offset + offs_bn[None, :])
{%- else %}
idx_n = offs_bn[None, :]
{%- endif %}
mask = offs_am[:, None] < m_size and offs_bn[None, :] < n_size
{%- if M_IS_VARYING or N_IS_VARYING %}
{{store_output(("idx_m", "idx_n"), "c", "mask", indent_width=16)}}
{%- else %}
{{store_output(("g", "idx_m", "idx_n"), "c", "mask", indent_width=16)}}
{%- endif %}
tidx += NUM_SMS
iterated_tiles += num_tiles
@@ -286,7 +394,7 @@ triton_scaled_grouped_mm_source = r"""
triton_scaled_grouped_mm_template = TritonTemplate(
name="scaled_grouped_mm",
grid=persistent_grouped_mm_grid,
source=triton_scaled_grouped_mm_source,
source=triton_grouped_mm_source,
)
@@ -297,7 +405,9 @@ def grouped_mm_args(
layout=None,
out_dtype=None,
):
mat1, mat2, offs = realize_inputs(mat1, mat2, offs)
mat1, mat2 = realize_inputs(mat1, mat2)
if offs is not None:
realize_inputs(offs)
mat1_size = mat1.get_size()
mat2_size = mat2.get_size()
@@ -348,46 +458,171 @@ def can_use_triton_kernel(
mat_b: TensorBox,
offs: Optional[TensorBox],
bias: Optional[TensorBox],
scale_result: Optional[TensorBox],
) -> bool:
a_shape = mat_a.get_size()
b_shape = mat_b.get_size()
a_stride = mat_a.get_stride()
b_stride = mat_b.get_stride()
if not has_triton_tma_device():
return False
# A must be contiguous 2d
a_layout_ok = (
len(a_shape) == 2
and a_stride[1] == 1
and a_stride[0] == a_shape[1]
and a_shape[1] >= 32
)
# The _grouped_mm()/_scaled_grouped_mm() operators do not support
# bias or scale_result yet.
if bias is not None:
return False
if scale_result is not None:
return False
# B must be contiguous 3d with transposed last dimension
b_layout_ok = (
len(b_shape) == 3
and b_stride[2] == b_shape[1]
and b_stride[1] == 1
and b_stride[0] == (b_shape[1] * b_shape[2])
and b_shape[1] >= 32
)
return (
offs is not None
and bias is None
and has_triton_tma_device()
and a_layout_ok
and b_layout_ok
)
if len(mat_a.get_size()) == 2 or len(mat_b.get_size()) == 2:
return offs is not None
else:
return offs is None
def create_offsets(x, m1_size, m2_size, offs_size):
assert len(m1_size) == 2 and len(m2_size) == 3, (
"Autotuning _scaled_grouped_mm is only implemented for 2d-3d tensors"
m1_is_2d = len(m1_size) == 2
m2_is_2d = len(m2_size) == 2
if m1_is_2d:
if m2_is_2d:
k = V.graph.sizevars.size_hint(m1_size[1])
noffs = V.graph.sizevars.size_hint(offs_size[0])
step = k / noffs
return torch.linspace(
step, k, noffs, dtype=x.get_dtype(), device=x.get_device()
)
else:
m = V.graph.sizevars.size_hint(m1_size[0])
noffs = V.graph.sizevars.size_hint(offs_size[0])
step = m / noffs
return torch.linspace(
step, m, noffs, dtype=x.get_dtype(), device=x.get_device()
)
else:
if m2_is_2d:
n = V.graph.sizevars.size_hint(m2_size[0])
noffs = V.graph.sizevars.size_hint(offs_size[0])
step = n / noffs
return torch.linspace(
step, n, noffs, dtype=x.get_dtype(), device=x.get_device()
)
else:
return None
def _tuned_grouped_mm_common(
operator_name: str,
algorithm_name: str,
extern_kernel_choice: ExternKernelChoice,
kernel_template: TritonTemplate,
mat_a: TensorBox,
mat_b: TensorBox,
scale_a: Optional[TensorBox] = None,
scale_b: Optional[TensorBox] = None,
offs: Optional[TensorBox] = None,
bias: Optional[TensorBox] = None,
scale_result: Optional[TensorBox] = None,
out_dtype: Optional[torch.dtype] = None,
use_fast_accum: Optional[bool] = False,
layout: Optional[Layout] = None,
) -> TensorBox:
assert (scale_a is None) == (scale_b is None)
assert scale_result is None or scale_a is not None
m1_size, m2_size, layout, mat_a, mat_b, offs = grouped_mm_args(
mat_a, mat_b, offs, layout=layout, out_dtype=out_dtype
)
counters["aten_mm_info"][operator_name] += 1
log_message = f"Tuned {operator_name}: mat1_shape=%s, mat2_shape=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s"
log.info(
log_message,
m1_size,
m2_size,
mat_a.get_dtype(),
mat_b.get_dtype(),
layout,
)
check_supported_striding(mat_a, mat_b)
# workaround for Inductor not supporting optional tensor input arguments
input_nodes: list[Any] = [mat_a, mat_b]
if scale_a is not None:
input_nodes.append(realize_inputs(scale_a))
if scale_b is not None:
input_nodes.append(realize_inputs(scale_b))
if offs is not None:
input_nodes.append(realize_inputs(offs))
aten_choice = extern_kernel_choice.bind(
input_nodes,
layout,
out_dtype=out_dtype,
use_fast_accum=use_fast_accum,
)
choices: list[ChoiceCaller] = []
if use_aten_gemm_kernels():
choices.append(aten_choice)
_, is_nonzero = _is_static_problem(layout)
# Checking only for the equality of corresponding dims of
# multiplicands here, relying on meta function checks for
# everything else.
if is_nonzero and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result):
if len(m1_size) == 2:
if len(m2_size) == 2:
m, k1 = m1_size
k2, _ = m2_size
g = offs.get_size()[0]
V.graph.sizevars.guard_equals(k1, k2)
a_is_2d, b_is_2d = True, True
else:
g1 = offs.layout.size[0]
m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.guard_equals(g1, g2)
V.graph.sizevars.guard_equals(k1, k2)
a_is_2d, b_is_2d = True, False
else:
if len(m2_size) == 2:
g1 = offs.layout.size[0]
g2, m, k1 = m1_size
k2, _ = m2_size
g = V.graph.sizevars.guard_equals(g1, g2)
V.graph.sizevars.guard_equals(k1, k2)
a_is_2d, b_is_2d = False, True
else:
g1, m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.guard_equals(g1, g2)
V.graph.sizevars.guard_equals(k1, k2)
a_is_2d, b_is_2d = False, False
kwargs = {
"A_IS_2D": a_is_2d,
"B_IS_2D": b_is_2d,
"USE_FAST_ACCUM": use_fast_accum,
"NUM_SMS": get_num_sms(),
"USE_TMA_LOAD": True,
}
for config in early_config_prune(g, m, grouped_mm_configs(), kwargs):
kernel_template.maybe_append_choice(
choices,
input_nodes=input_nodes,
layout=layout,
num_stages=config.num_stages,
num_warps=config.num_warps,
**kwargs,
**config.kwargs,
)
input_gen_fns = {
4: lambda x: create_offsets(
x, m1_size, m2_size, offs.get_size() if offs is not None else None
),
}
return autotune_select_algorithm(
algorithm_name, choices, input_nodes, layout, input_gen_fns=input_gen_fns
)
m = V.graph.sizevars.size_hint(m1_size[0])
noffs = V.graph.sizevars.size_hint(offs_size[0])
step = m / noffs
return torch.linspace(step, m, noffs, dtype=x.get_dtype(), device=x.get_device())
@register_lowering(aten._scaled_grouped_mm.default, type_promotion_kind=None)
@@ -403,77 +638,21 @@ def tuned_scaled_grouped_mm(
use_fast_accum: bool = False,
layout: Optional[Layout] = None,
) -> TensorBox:
m1_size, m2_size, layout, mat_a, mat_b, offs = grouped_mm_args(
mat_a, mat_b, offs, layout=layout, out_dtype=out_dtype
)
"""Auto-tuning for _scaled_grouped_mm() operator."""
counters["aten_mm_info"]["aten._scaled_grouped_mm.default"] += 1
log.info(
"Tuned aten._scaled_grouped_mm.default: mat1_shape=%s, mat2_shape=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
m1_size,
m2_size,
mat_a.get_dtype(),
mat_b.get_dtype(),
return _tuned_grouped_mm_common(
"aten._scaled_grouped_mm.default",
"scaled_grouped_mm",
aten__scaled_grouped_mm,
triton_scaled_grouped_mm_template,
mat_a,
mat_b,
scale_a,
scale_b,
offs,
bias,
scale_result,
out_dtype,
use_fast_accum,
layout,
)
check_supported_striding(mat_a, mat_b)
scale_a, scale_b = realize_inputs(scale_a, scale_b)
# workaround for Inductor not supporting optional tensor input arguments
input_nodes: list[Any] = [mat_a, mat_b, scale_a, scale_b]
if offs is not None:
input_nodes.append(realize_inputs(offs))
if bias is not None:
input_nodes.append(realize_inputs(bias))
aten_choice = aten__scaled_grouped_mm.bind(
input_nodes,
layout,
out_dtype=out_dtype,
use_fast_accum=use_fast_accum,
)
choices: list[ChoiceCaller] = []
if use_aten_gemm_kernels():
choices.append(aten_choice)
_, is_nonzero = _is_static_problem(layout)
if is_nonzero and can_use_triton_kernel(mat_a, mat_b, offs, bias):
m, k1 = m1_size
g, k2, n = m2_size
k = V.graph.sizevars.guard_equals(k1, k2)
kwargs = {
"G": g,
"M": m,
"M_BUCKET": next_power_of_2(m),
"N": n,
"K": k,
"NUM_SMS": get_num_sms(),
"USE_TMA_LOAD": True,
"USE_TMA_STORE": False,
"USE_FAST_ACCUM": use_fast_accum,
}
for config in early_config_prune(scaled_grouped_mm_configs(), kwargs):
triton_scaled_grouped_mm_template.maybe_append_choice(
choices,
input_nodes=input_nodes,
layout=layout,
workspace_arg=get_tma_workspace_arg(
num_tma_descriptors=2,
device=mat_a.get_device(),
),
num_stages=config.num_stages,
num_warps=config.num_warps,
**kwargs,
**config.kwargs,
)
input_gen_fns = {
4: lambda x: create_offsets(x, m1_size, m2_size, offs.get_size()),
}
return autotune_select_algorithm(
"scaled_grouped_mm", choices, input_nodes, layout, input_gen_fns=input_gen_fns
)


@@ -33,6 +33,7 @@ import torch
from torch._dynamo.utils import set_feature_use
from torch._prims_common import compute_required_storage_length
from torch.utils._ordered_set import OrderedSet
from torch.utils._triton import triton_set_allocator
from ..triton_bundler import TritonBundler
from ..utils import prefix_is_reduction, triton_version_uses_attrs_dict
@@ -1114,6 +1115,8 @@ class CachingAutotuner(KernelInterface):
**self.configs[0].kwargs,
)
triton_set_allocator(self.triton_meta["device"])
if len(self.launchers) != 1:
if len(self.launchers) == 0:
start_time = time.time_ns()


@@ -7292,7 +7292,7 @@ def sigmoid(self: Tensor) -> Tensor:
return torch.empty_like(self, dtype=result_dtype)
def _compute_grouped_gemm_output_size(mat1, mat2, offs):
def _compute_grouped_mm_output_size(mat1, mat2, offs):
mat1_is_2d = mat1.dim() == 2
mat2_is_2d = mat2.dim() == 2
@@ -7316,33 +7316,205 @@ def _compute_grouped_gemm_output_size(mat1, mat2, offs):
return mat1.size(0), mat1.size(1), mat2.size(-1)
def _meta_grouped_mm_common(
mat_a: Tensor,
mat_b: Tensor,
scale_a: Optional[torch.Tensor],
scale_b: Optional[torch.Tensor],
offs: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
scale_result: Optional[torch.Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
use_fast_accum: bool = False,
):
torch._check(
(scale_a is None) == (scale_b is None),
lambda: "Either both scale factors are given, or none",
)
scaled = scale_a is not None and scale_b is not None
# Implementing all the checks from
# _grouped_mm_cuda()/_scaled_grouped_mm_cuda() code in
# aten/src/ATen/native/cuda/Blas.cpp.
if scaled:
torch._check(
mat_a.dtype == torch.float8_e4m3fn and mat_b.dtype == torch.float8_e4m3fn,
lambda: f"Expected inputs of E4M3 FP8 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.",
)
else:
torch._check(
mat_a.dtype == torch.bfloat16 and mat_b.dtype == torch.bfloat16,
lambda: f"Expected inputs of BF16 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.",
)
torch._check(
mat_a.dim() in [2, 3] and mat_b.dim() in [2, 3],
lambda: f"Multiplicands must be 2D or 3D but got mat_a.dim()={mat_a.dim()} and mat_b.dim()={mat_b.dim()}",
)
mat_a_is_2d = mat_a.dim() == 2
mat_b_is_2d = mat_b.dim() == 2
torch._check(
mat_a.shape[-1] % 16 == 0,
lambda: f"Expected mat_a.shape[-1] to be divisible by 16, but got mat_a.shape[-1]={mat_a.shape[1]}",
)
torch._check(
mat_b.shape[-2] % 16 == 0 and mat_b.shape[-1] % 16 == 0,
lambda: f"Expected mat_b.shape[-2] and mat_b.shape[-1] to be both divisble by 16 but got {mat_b.shape[-2]} and {mat_b.shape[-1]}", # noqa: B950
)
if scaled:
def is_row_major(mat):
mat_stride = mat.stride()
return mat_stride[-2] > 1 and mat_stride[-1] == 1
def is_col_major(mat):
mat_stride = mat.stride()
return mat_stride[-2] == 1 and mat_stride[-1] > 1
torch._check(
is_row_major(mat_a),
lambda: f"Expected mat_a tensor to be row major in the last two dimensions, got strides {mat_a.stride()[-2:]}",
)
torch._check(
is_col_major(mat_b),
lambda: f"Expected mat_b tensor to be column major in the last two dimensions, got strides {mat_b.stride()[-2:]}",
)
def check_valid_strides(mat_name, mat):
end_dim = mat.dim() - 1
alignment = 16 / mat.element_size()
mat_stride = mat.stride()
if mat_stride[end_dim - 1] == 1 and mat_stride[end_dim] >= max(
1, mat.shape[end_dim - 1]
):
torch._check(
mat_stride[end_dim] % alignment == 0,
lambda: f"Expected {mat_name} stride along {end_dim} dim to be multiple of 16 bytes, got {mat_stride[end_dim]}.",
)
elif mat_stride[end_dim] == 1 and mat_stride[end_dim - 1] >= max(
1, mat.shape[end_dim]
):
torch._check(
mat_stride[end_dim - 1] % alignment == 0,
lambda: f"Expected {mat_name} stride along {end_dim - 1} dim to be multiple of 16 bytes, got {mat_stride[end_dim - 1]}.", # noqa: B950
)
else:
torch._check(
False,
lambda: f"Expected {mat_name} to have a contiguous dimension and not be mat_a-overlapping, got {mat_stride} for strides and {mat.shape} for sizes.", # noqa: B950
)
check_valid_strides("mat_a", mat_a)
check_valid_strides("mat_b", mat_b)
if scale_a is not None and scale_b is not None:
torch._check(
scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32,
lambda: "Both scale_a and scale_b must be float (fp32) tensors, but got scale_a.dtype={scale_a.dtype} and scale_b.dtype={scale_b.dtype}.", # noqa: B950
)
def check_scale(scale_name, scale, mat, scaled_dim, scale_multiplier=1):
if mat.dim() == 2:
torch._check(
scale.dim() == 1,
lambda: f"Expected {scale_name} to be 1D tensor, but got {scale.dim()}D tensor.",
)
torch._check(
scale.is_contiguous(),
lambda: f"Expected {scale_name} to be contiguous.",
)
torch._check(
scale.shape[0] == mat.shape[scaled_dim] * scale_multiplier,
lambda: f"Expected {scale_name} to have {mat.shape[scaled_dim] * scale_multiplier} elements, got {scale.shape[0]} elements.", # noqa: B950
)
else:
torch._check(
scale.dim() == 2,
lambda: f"Expected {scale_name} to be 2D tensor, but got {scale.dim()}D tensor.",
)
torch._check(
scale.stride(1) == 1,
lambda: f"Expected {scale_name} to be contiguous in the last dimension.",
)
torch._check(
scale.shape[0] == mat.shape[0],
lambda: f"Expected {scale_name} batch dimension to be {mat.shape[0]}, got {scale.shape[0]}.",
)
torch._check(
scale.shape[1] == mat.shape[1 + scaled_dim],
lambda: f"Expected {scale_name} non-batch dimension to be {mat.shape[1 + scaled_dim]}, got {scale.shape[1]}.",
)
scale_multiplier = (
offs.shape[0] if offs is not None and mat_a_is_2d and mat_b_is_2d else 1
)
check_scale("scale_a", scale_a, mat_a, 0, scale_multiplier)
check_scale("scale_b", scale_b, mat_b, 1, scale_multiplier)
torch._check(
scale_result is None,
lambda: "Scale result tensor provided, but it is not supported yet.",
)
if mat_a_is_2d or mat_b_is_2d:
torch._check(
offs is not None,
lambda: f"Offsets tensor not provided, but is needed for {mat_a.dim()}D/{mat_b.dim()}D multiplicand layouts.",
)
if offs is not None: # to silence Mypy
torch._check(
offs.dim() == 1,
lambda: f"Offsets tensor must be 1D, but got offs.dim()={offs.dim()}.",
)
torch._check(
offs.dtype == torch.int32,
lambda: f"Offsets tensor must be integer (int32) tensor, but got {offs.dtype}.",
)
else:
torch._check(
offs is None,
lambda: "Offsets tensor provided, but is not needed for 3D/3D multiplicand layouts.",
)
torch._check(
bias is None,
lambda: "Bias tensor provided, but it is not supported yet.",
)
torch._check(
out_dtype is None or out_dtype == torch.bfloat16,
lambda: "If output dtype provided, it must be torch.bfloat16.",
)
out_size = _compute_grouped_mm_output_size(mat_a, mat_b, offs)
out_dtype = out_dtype or mat_a.dtype
return torch.empty(out_size, dtype=out_dtype, device=mat_a.device)
@register_meta(aten._grouped_mm)
@out_wrapper()
def grouped_mm(
mat1: Tensor,
mat2: Tensor,
mat_a: Tensor,
mat_b: Tensor,
offs: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
) -> Tensor:
torch._check(mat1.dim() == 2 or mat1.dim() == 3, lambda: "mat1 must be 2d or 3d")
torch._check(mat2.dim() == 2 or mat2.dim() == 3, lambda: "mat2 must be 2d or 3d")
torch._check(
(offs is not None) == (mat1.dim() == 2 or mat2.dim() == 2),
lambda: "Have to provide offsets if there is a 2d matrix, or no offset if both matrices are 3d",
return _meta_grouped_mm_common(
mat_a,
mat_b,
scale_a=None,
scale_b=None,
offs=offs,
bias=bias,
scale_result=None,
out_dtype=out_dtype,
)
if offs is not None:
torch._check(offs.dim() == 1, lambda: "offsets must be 1d")
out_dtype = out_dtype or mat1.dtype
torch._check(bias is None, lambda: "bias not supported yet")
out_size = _compute_grouped_gemm_output_size(mat1, mat2, offs)
out = mat1.new_empty(out_size, dtype=out_dtype)
return out
@register_meta([aten._scaled_grouped_mm.default])
def meta_scaled_grouped_mm(
@@ -7356,118 +7528,17 @@ def meta_scaled_grouped_mm(
out_dtype: Optional[torch.dtype] = None,
use_fast_accum: bool = False,
):
# Check dimensions
torch._check(
mat_a.dim() == 2 or mat_a.dim() == 3, lambda: "mat_a has to be 2 or 3d"
return _meta_grouped_mm_common(
mat_a,
mat_b,
scale_a=scale_a,
scale_b=scale_b,
offs=offs,
bias=bias,
scale_result=scale_result,
out_dtype=out_dtype,
use_fast_accum=use_fast_accum,
)
torch._check(
mat_b.dim() == 2 or mat_b.dim() == 3, lambda: "mat_b has to be 2 or 3d"
)
a_is_2d = mat_a.dim() == 2
b_is_2d = mat_b.dim() == 2
# Check offsets
torch._check(
(offs is not None) == (a_is_2d or b_is_2d),
lambda: "Have to provide offsets if there is a 2d matrix",
)
if offs is not None:
torch._check(offs.dim() == 1, lambda: "offs has to be 1D")
torch._check(offs.dtype == torch.int, lambda: "Offsets have to be int32")
# Check matrix sizes
torch._check(
mat_a.size(-1) % 16 == 0,
lambda: f"Expected trailing dimension of mat_a to be divisible by 16 but got mat1 shape: {mat_a.size()}",
)
torch._check(
mat_b.size(-2) % 16 == 0 and mat_b.size(-1) % 16 == 0,
lambda: f"Expected mat_b shape to be divisible by 16 but got mat_b shape: {mat_b.size()}",
)
# Check scales
torch._check(
scale_a.dtype == torch.float and scale_b.dtype == torch.float,
lambda: "Both scale_a and scale_b must be float (fp32) tensors.",
)
# Check scale dimensions
scale_multiplier = offs.size(0) if (a_is_2d and b_is_2d) else 1 # type: ignore[union-attr]
if a_is_2d:
torch._check(
scale_a.dim() == 1,
lambda: f"scale must be a 1D tensor for 2D mat_a, but got {scale_a.dim()}D",
)
torch._check(scale_a.is_contiguous(), lambda: "scale_a must be contiguous")
torch._check(
scale_a.size(0) == mat_a.size(0) * scale_multiplier,
lambda: "scale must have the same length as mat_a",
)
else:
torch._check(
scale_a.dim() == 2,
lambda: f"scale must be a 2D tensor for 3D mat_a, but got {scale_a.dim()}D",
)
torch._check(
scale_a.stride(1) == 1,
lambda: "scale_a must be contiguous in the last dimension",
)
torch._check(
scale_a.size(0) == mat_a.size(0),
lambda: "scale must have the same batch dimension as mat_a",
)
torch._check(
scale_a.size(1) == mat_a.size(1),
lambda: "scale must have the same first dimension as mat_a",
)
# Similar checks for scale_b
if b_is_2d:
torch._check(
scale_b.dim() == 1,
lambda: f"scale must be a 1D tensor for 2D mat_b, but got {scale_b.dim()}D",
)
torch._check(scale_b.is_contiguous(), lambda: "scale_b must be contiguous")
torch._check(
scale_b.size(0) == mat_b.size(1) * scale_multiplier,
lambda: "scale must have the same length as mat_b",
)
else:
torch._check(
scale_b.dim() == 2,
lambda: f"scale must be a 2D tensor for 3D mat_b, but got {scale_b.dim()}D",
)
torch._check(
scale_b.stride(1) == 1,
lambda: "scale_b must be contiguous in the last dimension",
)
torch._check(
scale_b.size(0) == mat_b.size(0),
lambda: "scale must have the same batch dimension as mat_b",
)
torch._check(
scale_b.size(1) == mat_b.size(2),
lambda: "scale must have the same last dimension as mat_b",
)
# Check bias
torch._check(bias is None, lambda: "Bias not supported yet")
# Check output dtype
out_dtype_ = out_dtype if out_dtype is not None else mat_a.dtype
torch._check(
out_dtype_ == torch.bfloat16,
lambda: "Only bf16 high precision output types are supported for grouped gemm",
)
# Compute output size
out_size = _compute_grouped_gemm_output_size(mat_a, mat_b, offs)
out = mat_a.new_empty(out_size, dtype=out_dtype)
return out
@register_meta(aten._softmax)


@@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
import functools
import hashlib
from typing import Optional
@functools.lru_cache(None)
@@ -50,9 +51,8 @@ def has_triton_tma_device():
):
# old API
try:
from triton.language.extra.cuda import ( # noqa: F401
experimental_device_tensormap_create1d,
experimental_device_tensormap_create2d,
from triton.language import ( # noqa: F401
_experimental_make_tensor_descriptor,
)
return True
@@ -122,3 +122,20 @@ def triton_hash_with_backend():
# Hash is upper case so that it can't contain any Python keywords.
return hashlib.sha256(key.encode("utf-8")).hexdigest().upper()
@functools.lru_cache(None)
def triton_set_allocator(device):
if has_triton_tma_device():
import torch
assert torch.cuda.current_device() == device
def alloc_fn(size: int, alignment: int, stream: Optional[int]):
return torch.empty(size, device=device, dtype=torch.int8)
import triton
triton.set_allocator(alloc_fn)
return None