Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Update auto-tuning support for _scaled_grouped_mm (#150944)
1. Enable strided inputs.
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs.
3. Fix the non-TMA load variant.
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor.
5. Fix cases where the group size along the K dimension is not a multiple of the block size along K.
6. Update the meta registration.
7. Update synthetic offsets creation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150944
Approved by: https://github.com/ngimel
Committed by: PyTorch MergeBot
Parent: 1339e88105
Commit: 09328eb02f
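As a quick orientation before the diff (which touches the ATen input checks, the Inductor grouped-MM Triton template and lowering, the meta registrations, the Triton utility helpers, and the tests): the change lets Inductor auto-tune `torch._scaled_grouped_mm` when compiled with max-autotune. A minimal usage sketch for the 2d/3d case, mirroring the shapes and scale layouts of the updated test; this snippet is an illustration added for this write-up, not part of the commit, and assumes a GPU with FP8 and Triton TMA support:

import torch

# Illustrative sizes; every size must be divisible by 16 (see the tests in the diff).
m, n, k, n_groups = 16, 32, 64, 4
device = "cuda"

a = torch.randn(m * n_groups, k, device=device).to(torch.float8_e4m3fn)        # 2d: groups stacked along M
b = torch.randn(n_groups, n, k, device=device).to(torch.float8_e4m3fn)         # 3d: one weight matrix per group
scale_a = torch.rand(m * n_groups, device=device, dtype=torch.float32)         # one scale per row of a
scale_b = torch.rand(n_groups, n, device=device, dtype=torch.float32)          # per-group, per-column scales
offs = torch.arange(m, n_groups * m + 1, m, device=device, dtype=torch.int32)  # cumulative group ends along M

f = torch.compile(
    torch._scaled_grouped_mm,
    options={"max_autotune": True, "max_autotune_gemm_backends": "TRITON"},
)
out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
        out_dtype=torch.bfloat16, use_fast_accum=True)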
@@ -1532,7 +1532,7 @@ namespace {
"D, arg ",
arg_idx);
TORCH_CHECK(
scale.is_contiguous(), "scale_a must be contiguous for arg ", arg_idx);
scale.is_contiguous(), "scale must be contiguous for arg ", arg_idx);
TORCH_CHECK(
scale.size(0) == mat.size(dim) * scale_multiplier,
"scale must have the same length as mat for arg ",

@@ -1545,8 +1545,8 @@ namespace {
"D for arg ",
arg_idx);
TORCH_CHECK(
scale.stride(1),
"scale_a must be contiguous in the last dimension for arg ",
scale.stride(1) == 1,
"scale must be contiguous in the last dimension for arg ",
arg_idx);
TORCH_CHECK(
scale.size(0) == mat.size(0),

@@ -1610,6 +1610,7 @@ bool use_fast_accum) {
TORCH_CHECK(!bias.has_value(), "Bias not supported yet");
TORCH_CHECK(!scale_result.has_value(), "Scale result not supported yet");
TORCH_CHECK(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");

if (offs.has_value()) {
@@ -1616,7 +1616,7 @@ class TestFP8Matmul(TestCase):
for a, b, ascale, bscale, out in zip(alist, blist, ascalelist, bscalelist, outlist):
out_ref = torch._scaled_mm(a, b.t(), ascale.view(-1, 1), bscale.view(1, -1),
out_dtype=torch.bfloat16, use_fast_accum=use_fast_accum)
self.assertEqual(out, out_ref)
self.assertEqual(out, out_ref, atol=1e-1, rtol=1e-2)

@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@xfailIfSM100OrLater

@@ -1626,14 +1626,19 @@ class TestFP8Matmul(TestCase):
@parametrize("use_torch_compile", [False, True])
def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, use_torch_compile):
device = "cuda"
m, n, k, n_groups = 16, 16, 16, 4  # all sizes have to be divisible by 16
m, n, k, n_groups = 16, 32, 64, 4  # all sizes have to be divisible by 16
a = torch.randn(m, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
b = torch.randn(n, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
scale_a = torch.arange(m * n_groups, device=device, dtype=torch.float32) / 4
scale_b = torch.arange(n * n_groups, device=device, dtype=torch.float32) / 4
scale_a = torch.rand(m * n_groups, device=device, dtype=torch.float32)
scale_b = torch.rand(n * n_groups, device=device, dtype=torch.float32)
offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32)
f = torch._scaled_grouped_mm
f = torch.compile(f) if use_torch_compile else f
f = torch.compile(
f,
options={
"max_autotune": True,
"max_autotune_gemm_backends": "TRITON",
}) if use_torch_compile else f
out = f(a, b.t(), scale_a, scale_b, offs=offs,
out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
offs_cpu = offs.cpu()

@@ -1657,7 +1662,7 @@ class TestFP8Matmul(TestCase):
def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile):
device = "cuda"
s_int = int(strided)
m, n, k, n_groups = 16, 32, 16, 4
m, n, k, n_groups = 16, 32, 64, 4
a = torch.randn(m * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
self.assertTrue(a.is_contiguous() is not strided)

@@ -1666,11 +1671,16 @@ class TestFP8Matmul(TestCase):
offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
if check_zero_size:
offs[0] = offs[1]
scale_a = torch.arange(n_groups * m, device="cuda", dtype=torch.float32)
scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32)
scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)

f = torch._scaled_grouped_mm
f = torch.compile(f, dynamic=False) if use_torch_compile else f
f = torch.compile(
f,
options={
"max_autotune": True,
"max_autotune_gemm_backends": "TRITON",
}) if use_torch_compile else f
out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
out_dtype=torch.bfloat16, use_fast_accum=fast_accum)

@@ -1682,7 +1692,7 @@ class TestFP8Matmul(TestCase):
ascalelist.append(scale_a[start:offs_cpu[i]])
outlist.append(out[start:offs_cpu[i]])
start = offs_cpu[i]
self.scaled_grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum)
self.scaled_grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum)

@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")

@@ -1694,16 +1704,21 @@ class TestFP8Matmul(TestCase):
def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided, use_torch_compile):
device = "cuda"
s_int = int(strided)
m, n, k, n_groups = 16, 32, 16, 4
m, n, k, n_groups = 16, 32, 64, 4
a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
self.assertTrue(a.is_contiguous() is not strided)
self.assertTrue(b.is_contiguous() is not strided)
scale_a = torch.ones(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)

f = torch._scaled_grouped_mm
f = torch.compile(f) if use_torch_compile else f
f = torch.compile(
f,
options={
"max_autotune": True,
"max_autotune_gemm_backends": "TRITON",
}) if use_torch_compile else f
out = f(a, b.transpose(-2, -1), scale_a, scale_b,
out_dtype=torch.bfloat16, use_fast_accum=fast_accum)

@@ -1719,20 +1734,25 @@ class TestFP8Matmul(TestCase):
def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided, use_torch_compile):
device = "cuda"
s_int = int(strided)
m, n, k, n_groups = 16, 32, 16, 4
m, n, k, n_groups = 16, 32, 64, 4
a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
b = torch.randn(n * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
self.assertTrue(a.is_contiguous() is not strided)
self.assertTrue(b.is_contiguous() is not strided)
scale_a = torch.arange(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
scale_b = torch.arange(n_groups * n, device="cuda", dtype=torch.float32)
scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32)
for check_zero_size in (True, False):
offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32)
if check_zero_size:
offs[0] = offs[1]

f = torch._scaled_grouped_mm
f = torch.compile(f) if use_torch_compile else f
f = torch.compile(
f,
options={
"max_autotune": True,
"max_autotune_gemm_backends": "TRITON",
}) if use_torch_compile else f
out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
offs_cpu = offs.cpu()

@@ -1743,7 +1763,7 @@ class TestFP8Matmul(TestCase):
bscalelist.append(scale_b[start:offs_cpu[i]])
outlist.append(out[:, start:offs_cpu[i]])
start = offs_cpu[i]
self.scaled_grouped_mm_helper(a, blist, scale_a, bscalelist, outlist, fast_accum)
self.scaled_grouped_mm_helper(a, blist, scale_a, bscalelist, outlist, fast_accum)

@unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
@@ -217,7 +217,6 @@ def mark_nodes_dislike_padding(
aten.convolution,
aten.convolution_backward,
aten._scaled_mm,
aten._scaled_grouped_mm,
]
)
# what's a better way to collect the reduction ops?
@@ -38,7 +38,8 @@ def persistent_mm_grid(M: int, N: int, meta: dict[str, Any], *, cdiv, min):

@SymbolicGridFn
def persistent_grouped_mm_grid(m, n, meta):
def persistent_grouped_mm_grid(*args):
meta = args[-1]
return (meta["NUM_SMS"], 1, 1)
@@ -10,7 +10,6 @@ from torch.utils._triton import has_triton_tma_device

from ..ir import ChoiceCaller, Layout, TensorBox
from ..lowering import register_lowering
from ..runtime.runtime_utils import next_power_of_2
from ..select_algorithm import (
autotune_select_algorithm,
ExternKernelChoice,

@@ -20,7 +19,7 @@ from ..select_algorithm import (
from ..utils import (
get_gpu_shared_memory,
get_num_sms,
get_tma_workspace_arg,
has_free_symbols,
use_aten_gemm_kernels,
)
from .mm_common import (

@@ -52,7 +51,7 @@ _NV_CONFIGS = [
num_stages=num_stages,
num_warps=num_warps,
)
for block_size_m in [64, 128]
for block_size_m in [16, 32, 64, 128]
for block_size_n in [64, 128, 256]
for block_size_k in [64, 128, 256]
for num_stages in [3, 4]

@@ -81,11 +80,11 @@ _AMD_CONFIGS = [
]

def scaled_grouped_mm_configs():
def grouped_mm_configs():
return _AMD_CONFIGS if torch.version.hip else _NV_CONFIGS

def early_config_prune(configs, named_args):
def early_config_prune(g, m, configs, named_args):
dtsize = 1
pruned_configs = []
for config in configs:

@@ -98,14 +97,26 @@ def early_config_prune(configs, named_args):
config.num_warps,
getattr(config, "num_consumer_groups", 0),
)
G, M, N, K = (
named_args["G"],
named_args["M_BUCKET"],
named_args["N"],
named_args["K"],
)

# 1. make sure we have enough smem
# 1. Prune NV configs depending on g and m.
if not torch.version.hip:
if not has_free_symbols((g, m)):
a_is_2d, b_is_2d = named_args["A_IS_2D"], named_args["B_IS_2D"]
m_avg = m // g if a_is_2d and not b_is_2d else m
if m_avg <= 16:
if BLOCK_M > 32:
continue
elif m_avg <= 32:
if BLOCK_M > 64:
continue
elif m_avg <= 64:
if BLOCK_M <= 16:
continue
else:
if BLOCK_M <= 32:
continue

# 2. make sure we have enough smem
max_shared_memory = get_gpu_shared_memory()

if torch.version.hip:

@@ -117,39 +128,7 @@ def early_config_prune(configs, named_args):

use_warp_specialization = num_consumer_groups >= 1

M_PER_GROUP = M // G
MIN_M_TILES = 32 if torch.version.hip else 64
# 2. make sure we don't load M tiles that are too big
if (
not use_warp_specialization
and BLOCK_M > MIN_M_TILES
and BLOCK_M > (M_PER_GROUP * 2)
):
continue
# 3. make sure we don't load N tiles that are too small
if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
continue

num_sm = get_num_sms()

N_TILES = N // BLOCK_N
MIN_N_TILES = 32 if torch.version.hip else 64
# 4. make sure we don't load N tiles that are too big
if (
not use_warp_specialization
and BLOCK_N > MIN_N_TILES
and M * N_TILES < num_sm
):
continue
# 5. make sure we don't load N tiles that are too small
if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
continue

# 6. make sure K can be evenly divided
if K % BLOCK_K != 0:
continue

# 7. make sure we can partition for ws
# 3. make sure we can partition for ws
if use_warp_specialization:
if num_warps != 4:
continue
@@ -166,47 +145,129 @@ def early_config_prune(configs, named_args):

# Copied from fbgemm grouped_gemm.py
triton_scaled_grouped_mm_source = r"""
{{def_kernel("a_ptr", "b_ptr", "a_scale_ptr", "b_scale_ptr", "m_sizes")}}
triton_grouped_mm_source = r"""
{%- if A_IS_2D or B_IS_2D %}
{{def_kernel("a_ptr", "b_ptr", "scale_a_ptr", "scale_b_ptr", "offsets_ptr")}}
{%- else %}
{{def_kernel("a_ptr", "b_ptr", "scale_a_ptr", "scale_b_ptr")}}
{%- endif %}
tidx = tl.program_id(0)

dtype = tl.float8e4nv
TMA_SIZE: tl.constexpr = tl.constexpr(128)
{%- set M_IS_VARYING = A_IS_2D and not B_IS_2D %}
{%- set N_IS_VARYING = not A_IS_2D and B_IS_2D %}
{%- set K_IS_VARYING = A_IS_2D and B_IS_2D %}

workspace_base = ws_ptr + tidx * 2 * TMA_SIZE
c_desc_ptr = None
{%- if A_IS_2D %}
{%- if B_IS_2D %}
G = {{size("offsets_ptr", 0)}}
{%- else %}
G = {{size("b_ptr", 0)}}
{%- endif %}
{%- else %}
{%- if B_IS_2D %}
G = {{size("a_ptr", 0)}}
{%- else %}
G = {{size("a_ptr", 0)}}
{%- endif %}
{%- endif %}

a_desc_ptr = workspace_base
b_desc_ptr = workspace_base + TMA_SIZE
# the b_ptr tensor is given with its last two dims transposed, revert here

triton.language.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=a_desc_ptr,
global_address=a_ptr,
load_size=[BLOCK_M, BLOCK_K],
global_size=[M, K],
element_ty=a_ptr.dtype.element_ty,
M = {{size("a_ptr", -2)}}
N = {{size("b_ptr", -1)}}
K = {{size("a_ptr", -1)}}

A_STRIDE_M = {{stride("a_ptr", -2)}}
A_STRIDE_K = {{stride("a_ptr", -1)}}
{%- if not A_IS_2D %}
A_STRIDE_G = {{stride("a_ptr", 0)}}
SCALE_A_STRIDE_G = {{stride("scale_a_ptr", 0)}}
{%- endif %}
B_STRIDE_N = {{stride("b_ptr", -1)}}
B_STRIDE_K = {{stride("b_ptr", -2)}}
{%- if not B_IS_2D %}
B_STRIDE_G = {{stride("b_ptr", 0)}}
SCALE_B_STRIDE_G = {{stride("scale_b_ptr", 0)}}
{%- endif %}

# fixme: a_desc = tl.make_tensor_descriptor(
a_desc = tl._experimental_make_tensor_descriptor(
a_ptr,
{%- if A_IS_2D %}
shape=[M, K],
# fixme: strides=[A_STRIDE_M, A_STRIDE_K],
strides=[{{stride("a_ptr", -2)}}, {{stride("a_ptr", -1)}}],
block_shape=[BLOCK_M, BLOCK_K],
{%- else %}
shape=[G, M, K],
# fixme: strides=[A_STRIDE_G, A_STRIDE_M, A_STRIDE_K],
strides=[{{stride("a_ptr", 0)}}, {{stride("a_ptr", -2)}}, {{stride("a_ptr", -1)}}],
block_shape=[1, BLOCK_M, BLOCK_K],
{%- endif %}
)
triton.language.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=b_desc_ptr,
global_address=b_ptr,
load_size=[BLOCK_N, BLOCK_K],
global_size=[N * G, K],
element_ty=b_ptr.dtype.element_ty,
)
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)

M_end_offset = 0
# fixme: b_desc = tl.make_tensor_descriptor(
b_desc = tl._experimental_make_tensor_descriptor(
b_ptr,
{%- if B_IS_2D %}
shape=[N, K],
# fixme: strides=[B_STRIDE_N, B_STRIDE_K],
strides=[{{stride("b_ptr", -1)}}, {{stride("b_ptr", -2)}}],
block_shape=[BLOCK_N, BLOCK_K],
{%- else %}
shape=[G, N, K],
# fixme: strides=[B_STRIDE_G, B_STRIDE_N, B_STRIDE_K],
strides=[{{stride("b_ptr", 0)}}, {{stride("b_ptr", -1)}}, {{stride("b_ptr", -2)}}],
block_shape=[1, BLOCK_N, BLOCK_K],
{%- endif %}
)

{%- if M_IS_VARYING %}
m_end_offset = 0
{%- endif %}
{%- if N_IS_VARYING %}
n_end_offset = 0
{%- endif %}
{%- if K_IS_VARYING %}
k_end_offset = 0
{%- endif %}
iterated_tiles = 0
for g in tl.range(G):
{%- if M_IS_VARYING %}
# Move across groups
M_start_offset = M_end_offset
M_end_offset = tl.load(m_sizes + g)
m_size = M_end_offset - M_start_offset
m_start_offset = m_end_offset
m_end_offset = tl.load(offsets_ptr + g)
m_size = m_end_offset - m_start_offset
m_scale_start_offset = m_start_offset
{%- else %}
m_start_offset = 0
m_size = M
m_scale_start_offset = g * M
{%- endif %}

{%- if N_IS_VARYING %}
# Move across groups
n_start_offset = n_end_offset
n_end_offset = tl.load(offsets_ptr + g)
n_size = n_end_offset - n_start_offset
n_scale_start_offset = n_start_offset
{%- else %}
n_start_offset = 0
n_size = N
n_scale_start_offset = g * N
{%- endif %}

if m_size > 0 and n_size > 0:
{%- if K_IS_VARYING %}
# Move across groups
k_start_offset = k_end_offset
k_end_offset = tl.load(offsets_ptr + g)
k_size = k_end_offset - k_start_offset
{%- else %}
k_start_offset = 0
k_size = K
{%- endif %}

if m_size > 0:
N_start_offset = g.to(tl.int64) * N
n_size = N
num_m_tiles = tl.cdiv(m_size, BLOCK_M)
num_n_tiles = tl.cdiv(n_size, BLOCK_N)
num_tiles = num_m_tiles * num_n_tiles
@@ -219,64 +280,111 @@ triton_scaled_grouped_mm_source = r"""
tile_n_idx = gidx // num_m_tiles

accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
tl.static_assert(K % BLOCK_K == 0)
if USE_TMA_LOAD:
m_offset = (M_start_offset + tile_m_idx * BLOCK_M).to(tl.int32)
n_offset = (N_start_offset + tile_n_idx * BLOCK_N).to(tl.int32)
for k_offset in range(0, K, BLOCK_K):
a = tl._experimental_descriptor_load(
a_desc_ptr,
[m_offset, k_offset],
[BLOCK_M, BLOCK_K],
dtype,
)
b = tl._experimental_descriptor_load(
b_desc_ptr,
[n_offset, k_offset],
[BLOCK_N, BLOCK_K],
dtype,
)
if USE_FAST_ACCUM:
accumulator = tl.dot(a, b.T, accumulator)
else:
accumulator += tl.dot(a, b.T)
else:
offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
a_ptrs = (
a_desc_ptr
+ (M_start_offset + offs_am[:, None]) * K
+ offs_k[None, :]
)
b_ptrs = (
b_desc_ptr
+ (N_start_offset + offs_bn[:, None]) * K
+ offs_k[None, :]
)
for k_offset in range(0, K, BLOCK_K):
a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)
accumulator += tl.dot(a, b.T)
a_ptrs += BLOCK_K
b_ptrs += BLOCK_K

{%- if USE_TMA_LOAD %}
m_offset = (m_start_offset + tile_m_idx * BLOCK_M).to(tl.int32)
n_offset = (n_start_offset + tile_n_idx * BLOCK_N).to(tl.int32)

for k_offset in range(0, k_size, BLOCK_K):
{%- if A_IS_2D %}
a = a_desc.load([m_offset, k_start_offset + k_offset])
{%- else %}
a = a_desc.load([g, m_offset, k_start_offset + k_offset]).reshape(BLOCK_M, BLOCK_K)
{%- endif %}
{%- if B_IS_2D %}
b = b_desc.load([n_offset, k_start_offset + k_offset])
{%- else %}
b = b_desc.load([g, n_offset, k_start_offset + k_offset]).reshape(BLOCK_N, BLOCK_K)
{%- endif %}

{%- if K_IS_VARYING %}
if k_offset + BLOCK_K > k_size:
group_offs_k = k_offset + tl.arange(0, BLOCK_K)
a = tl.where(group_offs_k < k_size, a, 0)
b = tl.where(group_offs_k < k_size, b, 0)
{%- endif %}

{%- if USE_FAST_ACCUM %}
accumulator = tl.dot(a, b.T, accumulator)
{%- else %}
accumulator += tl.dot(a, b.T)
{%- endif %}
{%- else %}
offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = k_start_offset + tl.arange(0, BLOCK_K)
a_ptrs = (
a_ptr
{%- if not A_IS_2D %}
+ g * A_STRIDE_G
{%- endif %}
+ (m_start_offset + offs_am[:, None]) * A_STRIDE_M
+ offs_k[None, :] * A_STRIDE_K
)
b_ptrs = (
b_ptr
{%- if not B_IS_2D %}
+ g * B_STRIDE_G
{%- endif %}
+ (n_start_offset + offs_bn[:, None]) * B_STRIDE_N
+ offs_k[None, :] * B_STRIDE_K
)
for k_offset in range(0, k_size, BLOCK_K):
a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)
if k_offset + BLOCK_K > k_size:
group_offs_k = k_offset + tl.arange(0, BLOCK_K)
a = tl.where(group_offs_k < k_size, a, 0)
b = tl.where(group_offs_k < k_size, b, 0)
{%- if USE_FAST_ACCUM %}
accumulator = tl.dot(a, b.T, accumulator)
{%- else %}
accumulator += tl.dot(a, b.T)
{%- endif %}
a_ptrs += BLOCK_K
b_ptrs += BLOCK_K
{%- endif %}

offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
a_scale = tl.load(
a_scale_ptr + M_start_offset + offs_am[:, None],
scale_a = tl.load(
scale_a_ptr
{%- if A_IS_2D %}
+ m_scale_start_offset
{%- else %}
+ g * SCALE_A_STRIDE_G
{%- endif %}
+ offs_am[:, None],
mask=offs_am[:, None] < m_size,
)
b_scale = tl.load(
b_scale_ptr + N_start_offset + offs_bn[None, :],
scale_b = tl.load(
scale_b_ptr
{%- if B_IS_2D %}
+ n_scale_start_offset
{%- else %}
+ g * SCALE_B_STRIDE_G
{%- endif %}
+ offs_bn[None, :],
mask=offs_bn[None, :] < n_size,
)
c = accumulator.to(tl.float32) * a_scale * b_scale
c = accumulator.to(tl.float32) * scale_a * scale_b

idx_m = (M_start_offset + offs_am[:, None])
{%- if M_IS_VARYING %}
idx_m = (m_start_offset + offs_am[:, None])
{%- else %}
idx_m = offs_am[:, None]
{%- endif %}
{%- if N_IS_VARYING %}
idx_n = (n_start_offset + offs_bn[None, :])
{%- else %}
idx_n = offs_bn[None, :]
{%- endif %}
mask = offs_am[:, None] < m_size and offs_bn[None, :] < n_size
{%- if M_IS_VARYING or N_IS_VARYING %}
{{store_output(("idx_m", "idx_n"), "c", "mask", indent_width=16)}}
{%- else %}
{{store_output(("g", "idx_m", "idx_n"), "c", "mask", indent_width=16)}}
{%- endif %}
tidx += NUM_SMS

iterated_tiles += num_tiles
@@ -286,7 +394,7 @@ triton_scaled_grouped_mm_source = r"""
triton_scaled_grouped_mm_template = TritonTemplate(
name="scaled_grouped_mm",
grid=persistent_grouped_mm_grid,
source=triton_scaled_grouped_mm_source,
source=triton_grouped_mm_source,
)
@@ -297,7 +405,9 @@ def grouped_mm_args(
layout=None,
out_dtype=None,
):
mat1, mat2, offs = realize_inputs(mat1, mat2, offs)
mat1, mat2 = realize_inputs(mat1, mat2)
if offs is not None:
realize_inputs(offs)
mat1_size = mat1.get_size()
mat2_size = mat2.get_size()
@@ -348,46 +458,171 @@ def can_use_triton_kernel(
mat_b: TensorBox,
offs: Optional[TensorBox],
bias: Optional[TensorBox],
scale_result: Optional[TensorBox],
) -> bool:
a_shape = mat_a.get_size()
b_shape = mat_b.get_size()
a_stride = mat_a.get_stride()
b_stride = mat_b.get_stride()
if not has_triton_tma_device():
return False

# A must be contiguous 2d
a_layout_ok = (
len(a_shape) == 2
and a_stride[1] == 1
and a_stride[0] == a_shape[1]
and a_shape[1] >= 32
)
# The _grouped_mm()/_scaled_grouped_mm() operator do not support
# bias nor scale_result yet.
if bias is not None:
return False
if scale_result is not None:
return False

# B must be contiguous 3d with transposed last dimension
b_layout_ok = (
len(b_shape) == 3
and b_stride[2] == b_shape[1]
and b_stride[1] == 1
and b_stride[0] == (b_shape[1] * b_shape[2])
and b_shape[1] >= 32
)

return (
offs is not None
and bias is None
and has_triton_tma_device()
and a_layout_ok
and b_layout_ok
)
if len(mat_a.get_size()) == 2 or len(mat_b.get_size()) == 2:
return offs is not None
else:
return offs is None

def create_offsets(x, m1_size, m2_size, offs_size):
assert len(m1_size) == 2 and len(m2_size) == 3, (
"Autotuning _scaled_grouped_mm is only implemented for 2d-3d tensors"
m1_is_2d = len(m1_size) == 2
m2_is_2d = len(m2_size) == 2
if m1_is_2d:
if m2_is_2d:
k = V.graph.sizevars.size_hint(m1_size[1])
noffs = V.graph.sizevars.size_hint(offs_size[0])
step = k / noffs
return torch.linspace(
step, k, noffs, dtype=x.get_dtype(), device=x.get_device()
)
else:
m = V.graph.sizevars.size_hint(m1_size[0])
noffs = V.graph.sizevars.size_hint(offs_size[0])
step = m / noffs
return torch.linspace(
step, m, noffs, dtype=x.get_dtype(), device=x.get_device()
)
else:
if m2_is_2d:
n = V.graph.sizevars.size_hint(m2_size[0])
noffs = V.graph.sizevars.size_hint(offs_size[0])
step = n / noffs
return torch.linspace(
step, n, noffs, dtype=x.get_dtype(), device=x.get_device()
)
else:
return None

def _tuned_grouped_mm_common(
operator_name: str,
algorithm_name: str,
extern_kernel_choice: ExternKernelChoice,
kernel_template: TritonTemplate,
mat_a: TensorBox,
mat_b: TensorBox,
scale_a: Optional[TensorBox] = None,
scale_b: Optional[TensorBox] = None,
offs: Optional[TensorBox] = None,
bias: Optional[TensorBox] = None,
scale_result: Optional[TensorBox] = None,
out_dtype: Optional[torch.dtype] = None,
use_fast_accum: Optional[bool] = False,
layout: Optional[Layout] = None,
) -> TensorBox:
assert (scale_a is None) == (scale_b is None)
assert scale_result is None or scale_a is not None

m1_size, m2_size, layout, mat_a, mat_b, offs = grouped_mm_args(
mat_a, mat_b, offs, layout=layout, out_dtype=out_dtype
)
counters["aten_mm_info"][operator_name] += 1
log_message = f"Tuned {operator_name}: mat1_shape=%s, mat2_shape=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s"
log.info(
log_message,
m1_size,
m2_size,
mat_a.get_dtype(),
mat_b.get_dtype(),
layout,
)
check_supported_striding(mat_a, mat_b)

# workaround for Inductor not supporting optional tensor input arguments
input_nodes: list[Any] = [mat_a, mat_b]
if scale_a is not None:
input_nodes.append(realize_inputs(scale_a))
if scale_b is not None:
input_nodes.append(realize_inputs(scale_b))
if offs is not None:
input_nodes.append(realize_inputs(offs))

aten_choice = extern_kernel_choice.bind(
input_nodes,
layout,
out_dtype=out_dtype,
use_fast_accum=use_fast_accum,
)

choices: list[ChoiceCaller] = []
if use_aten_gemm_kernels():
choices.append(aten_choice)

_, is_nonzero = _is_static_problem(layout)

# Checking only for the equality of correspoding dims of
# multiplicands here, relying on meta function checks for
# everything else.
if is_nonzero and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result):
if len(m1_size) == 2:
if len(m2_size) == 2:
m, k1 = m1_size
k2, _ = m2_size
g = offs.get_size()[0]
V.graph.sizevars.guard_equals(k1, k2)
a_is_2d, b_is_2d = True, True
else:
g1 = offs.layout.size[0]
m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.guard_equals(g1, g2)
V.graph.sizevars.guard_equals(k1, k2)
a_is_2d, b_is_2d = True, False
else:
if len(m2_size) == 2:
g1 = offs.layout.size[0]
g2, m, k1 = m1_size
k2, _ = m2_size
g = V.graph.sizevars.guard_equals(g1, g2)
V.graph.sizevars.guard_equals(k1, k2)
a_is_2d, b_is_2d = False, True
else:
g1, m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.guard_equals(g1, g2)
V.graph.sizevars.guard_equals(k1, k2)
a_is_2d, b_is_2d = False, False

kwargs = {
"A_IS_2D": a_is_2d,
"B_IS_2D": b_is_2d,
"USE_FAST_ACCUM": use_fast_accum,
"NUM_SMS": get_num_sms(),
"USE_TMA_LOAD": True,
}

for config in early_config_prune(g, m, grouped_mm_configs(), kwargs):
kernel_template.maybe_append_choice(
choices,
input_nodes=input_nodes,
layout=layout,
num_stages=config.num_stages,
num_warps=config.num_warps,
**kwargs,
**config.kwargs,
)

input_gen_fns = {
4: lambda x: create_offsets(
x, m1_size, m2_size, offs.get_size() if offs is not None else None
),
}
return autotune_select_algorithm(
algorithm_name, choices, input_nodes, layout, input_gen_fns=input_gen_fns
)
m = V.graph.sizevars.size_hint(m1_size[0])
noffs = V.graph.sizevars.size_hint(offs_size[0])
step = m / noffs
return torch.linspace(step, m, noffs, dtype=x.get_dtype(), device=x.get_device())

@register_lowering(aten._scaled_grouped_mm.default, type_promotion_kind=None)
@@ -403,77 +638,21 @@ def tuned_scaled_grouped_mm(
use_fast_accum: bool = False,
layout: Optional[Layout] = None,
) -> TensorBox:
m1_size, m2_size, layout, mat_a, mat_b, offs = grouped_mm_args(
mat_a, mat_b, offs, layout=layout, out_dtype=out_dtype
)
"""Auto-tuning for _scaled_grouped_mm() operator."""

counters["aten_mm_info"]["aten._scaled_grouped_mm.default"] += 1
log.info(
"Tuned aten._scaled_grouped_mm.default: mat1_shape=%s, mat2_shape=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
m1_size,
m2_size,
mat_a.get_dtype(),
mat_b.get_dtype(),
return _tuned_grouped_mm_common(
"aten._scaled_grouped_mm.default",
"scaled_grouped_mm",
aten__scaled_grouped_mm,
triton_scaled_grouped_mm_template,
mat_a,
mat_b,
scale_a,
scale_b,
offs,
bias,
scale_result,
out_dtype,
use_fast_accum,
layout,
)

check_supported_striding(mat_a, mat_b)

scale_a, scale_b = realize_inputs(scale_a, scale_b)

# workaround for Inductor not supporting optional tensor input arguments
input_nodes: list[Any] = [mat_a, mat_b, scale_a, scale_b]
if offs is not None:
input_nodes.append(realize_inputs(offs))
if bias is not None:
input_nodes.append(realize_inputs(bias))

aten_choice = aten__scaled_grouped_mm.bind(
input_nodes,
layout,
out_dtype=out_dtype,
use_fast_accum=use_fast_accum,
)

choices: list[ChoiceCaller] = []
if use_aten_gemm_kernels():
choices.append(aten_choice)

_, is_nonzero = _is_static_problem(layout)

if is_nonzero and can_use_triton_kernel(mat_a, mat_b, offs, bias):
m, k1 = m1_size
g, k2, n = m2_size
k = V.graph.sizevars.guard_equals(k1, k2)
kwargs = {
"G": g,
"M": m,
"M_BUCKET": next_power_of_2(m),
"N": n,
"K": k,
"NUM_SMS": get_num_sms(),
"USE_TMA_LOAD": True,
"USE_TMA_STORE": False,
"USE_FAST_ACCUM": use_fast_accum,
}
for config in early_config_prune(scaled_grouped_mm_configs(), kwargs):
triton_scaled_grouped_mm_template.maybe_append_choice(
choices,
input_nodes=input_nodes,
layout=layout,
workspace_arg=get_tma_workspace_arg(
num_tma_descriptors=2,
device=mat_a.get_device(),
),
num_stages=config.num_stages,
num_warps=config.num_warps,
**kwargs,
**config.kwargs,
)

input_gen_fns = {
4: lambda x: create_offsets(x, m1_size, m2_size, offs.get_size()),
}
return autotune_select_algorithm(
"scaled_grouped_mm", choices, input_nodes, layout, input_gen_fns=input_gen_fns
)
@@ -33,6 +33,7 @@ import torch
from torch._dynamo.utils import set_feature_use
from torch._prims_common import compute_required_storage_length
from torch.utils._ordered_set import OrderedSet
from torch.utils._triton import triton_set_allocator

from ..triton_bundler import TritonBundler
from ..utils import prefix_is_reduction, triton_version_uses_attrs_dict

@@ -1114,6 +1115,8 @@ class CachingAutotuner(KernelInterface):
**self.configs[0].kwargs,
)

triton_set_allocator(self.triton_meta["device"])

if len(self.launchers) != 1:
if len(self.launchers) == 0:
start_time = time.time_ns()
@@ -7292,7 +7292,7 @@ def sigmoid(self: Tensor) -> Tensor:
return torch.empty_like(self, dtype=result_dtype)

def _compute_grouped_gemm_output_size(mat1, mat2, offs):
def _compute_grouped_mm_output_size(mat1, mat2, offs):
mat1_is_2d = mat1.dim() == 2
mat2_is_2d = mat2.dim() == 2
@@ -7316,33 +7316,205 @@ def _compute_grouped_gemm_output_size(mat1, mat2, offs):
return mat1.size(0), mat1.size(1), mat2.size(-1)

def _meta_grouped_mm_common(
mat_a: Tensor,
mat_b: Tensor,
scale_a: Optional[torch.Tensor],
scale_b: Optional[torch.Tensor],
offs: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
scale_result: Optional[torch.Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
use_fast_accum: bool = False,
):
torch._check(
(scale_a is None) == (scale_b is None),
lambda: "Either both scale factors are given, or none",
)
scaled = scale_a is not None and scale_b is not None

# Implementing all the checks from
# _grouped_mm_cuda()/_scaled_grouped_mm_cuda() code in
# aten/src/ATen/native/cuda/Blas.cpp.

if scaled:
torch._check(
mat_a.dtype == torch.float8_e4m3fn and mat_b.dtype == torch.float8_e4m3fn,
lambda: f"Expected inputs of E4M3 FP8 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.",
)
else:
torch._check(
mat_a.dtype == torch.bfloat16 and mat_b.dtype == torch.bfloat16,
lambda: f"Expected inputs of BF16 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.",
)

torch._check(
mat_a.dim() in [2, 3] and mat_b.dim() in [2, 3],
lambda: f"Multiplicands must be 2D or 3D but got mat_a.dim()={mat_a.dim()} and mat_b.dim()={mat_b.dim()}",
)

mat_a_is_2d = mat_a.dim() == 2
mat_b_is_2d = mat_b.dim() == 2

torch._check(
mat_a.shape[-1] % 16 == 0,
lambda: f"Expected mat_a.shape[-1] to be divisible by 16, but got mat_a.shape[-1]={mat_a.shape[1]}",
)
torch._check(
mat_b.shape[-2] % 16 == 0 and mat_b.shape[-1] % 16 == 0,
lambda: f"Expected mat_b.shape[-2] and mat_b.shape[-1] to be both divisble by 16 but got {mat_b.shape[-2]} and {mat_b.shape[-1]}",  # noqa: B950
)

if scaled:

def is_row_major(mat):
mat_stride = mat.stride()
return mat_stride[-2] > 1 and mat_stride[-1] == 1

def is_col_major(mat):
mat_stride = mat.stride()
return mat_stride[-2] == 1 and mat_stride[-1] > 1

torch._check(
is_row_major(mat_a),
lambda: f"Expected mat_a tensor to be row major in the last two dimensions, got strides {mat_a.stride()[-2:]}",
)
torch._check(
is_col_major(mat_b),
lambda: f"Expected mat_b tensor to be column major in the last two dimensions, got strides {mat_b.stride()[-2:]}",
)

def check_valid_strides(mat_name, mat):
end_dim = mat.dim() - 1
alignment = 16 / mat.element_size()
mat_stride = mat.stride()
if mat_stride[end_dim - 1] == 1 and mat_stride[end_dim] >= max(
1, mat.shape[end_dim - 1]
):
torch._check(
mat_stride[end_dim] % alignment == 0,
lambda: f"Expected {mat_name} stride along {end_dim} dim to be multiple of 16 bytes, got {mat_stride[end_dim]}.",
)
elif mat_stride[end_dim] == 1 and mat_stride[end_dim - 1] >= max(
1, mat.shape[end_dim]
):
torch._check(
mat_stride[end_dim - 1] % alignment == 0,
lambda: f"Expected {mat_name} stride along {end_dim - 1} dim to be multiple of 16 bytes, got {mat_stride[end_dim - 1]}.",  # noqa: B950
)
else:
torch._check(
False,
lambda: f"Expected {mat_name} to have a contiguous dimension and not be mat_a-overlapping, got {mat_stride} for strides and {mat.shape} for sizes.",  # noqa: B950
)

check_valid_strides("mat_a", mat_a)
check_valid_strides("mat_b", mat_b)

if scale_a is not None and scale_b is not None:
torch._check(
scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32,
lambda: "Both scale_a and scale_b must be float (fp32) tensors, but got scale_a.dtype={scale_a.dtype} and scale_b.dtype={scale_b.dtype}.",  # noqa: B950
)

def check_scale(scale_name, scale, mat, scaled_dim, scale_multiplier=1):
if mat.dim() == 2:
torch._check(
scale.dim() == 1,
lambda: f"Expected {scale_name} to be 1D tensor, but got {scale.dim()}D tensor.",
)
torch._check(
scale.is_contiguous(),
lambda: f"Expected {scale_name} to be contiguous.",
)
torch._check(
scale.shape[0] == mat.shape[scaled_dim] * scale_multiplier,
lambda: f"Expected {scale_name} to have {mat.shape[scaled_dim] * scale_multiplier} elements, got {scale.shape[0]} elements.",  # noqa: B950
)
else:
torch._check(
scale.dim() == 2,
lambda: f"Expected {scale_name} to be 2D tensor, but got {scale.dim()}D tensor.",
)
torch._check(
scale.stride(1) == 1,
lambda: f"Expected {scale_name} to be contiguous in the last dimension.",
)
torch._check(
scale.shape[0] == mat.shape[0],
lambda: f"Expected {scale_name} batch dimension to be {mat.shape[0]}, got {scale.shape[0]}.",
)
torch._check(
scale.shape[1] == mat.shape[1 + scaled_dim],
lambda: f"Expected {scale_name} non-batch dimension to be {mat.shape[1 + scaled_dim]}, got {scale.shape[1]}.",
)

scale_multiplier = (
offs.shape[0] if offs is not None and mat_a_is_2d and mat_b_is_2d else 1
)
check_scale("scale_a", scale_a, mat_a, 0, scale_multiplier)
check_scale("scale_b", scale_b, mat_b, 1, scale_multiplier)

torch._check(
scale_result is None,
lambda: "Scale result tensor provided, but it is not supported yet.",
)

if mat_a_is_2d or mat_b_is_2d:
torch._check(
offs is not None,
lambda: f"Offsets tensor not provided, but is needed for {mat_a.dim()}D/{mat_b.dim()}D multiplicand layouts.",
)
if offs is not None:  # to silence Mypy
torch._check(
offs.dim() == 1,
lambda: f"Offsets tensor must be 1D, but got offs.dim()={offs.dim()}.",
)
torch._check(
offs.dtype == torch.int32,
lambda: f"Offsets tensor must be integer (int32) tensor, but got {offs.dtype}.",
)
else:
torch._check(
offs is None,
lambda: "Offsets tensor provided, but is not needed for 3D/3D multiplicand layouts.",
)

torch._check(
bias is None,
lambda: "Bias tensor provided, but it is not supported yet.",
)

torch._check(
out_dtype is None or out_dtype == torch.bfloat16,
lambda: "If output dtype provided, it must be torch.bfloat16.",
)

out_size = _compute_grouped_mm_output_size(mat_a, mat_b, offs)
out_dtype = out_dtype or mat_a.dtype
return torch.empty(out_size, dtype=out_dtype, device=mat_a.device)

@register_meta(aten._grouped_mm)
@out_wrapper()
def grouped_mm(
mat1: Tensor,
mat2: Tensor,
mat_a: Tensor,
mat_b: Tensor,
offs: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
) -> Tensor:
torch._check(mat1.dim() == 2 or mat1.dim() == 3, lambda: "mat1 must be 2d or 3d")
torch._check(mat2.dim() == 2 or mat2.dim() == 3, lambda: "mat2 must be 2d or 3d")
torch._check(
(offs is not None) == (mat1.dim() == 2 or mat2.dim() == 2),
lambda: "Have to provide offsets if there is a 2d matrix, or no offset if both matrices are 3d",
return _meta_grouped_mm_common(
mat_a,
mat_b,
scale_a=None,
scale_b=None,
offs=offs,
bias=bias,
scale_result=None,
out_dtype=out_dtype,
)

if offs is not None:
torch._check(offs.dim() == 1, lambda: "offsets must be 1d")

out_dtype = out_dtype or mat1.dtype
torch._check(bias is None, lambda: "bias not supported yet")

out_size = _compute_grouped_gemm_output_size(mat1, mat2, offs)
out = mat1.new_empty(out_size, dtype=out_dtype)

return out

@register_meta([aten._scaled_grouped_mm.default])
def meta_scaled_grouped_mm(
@@ -7356,118 +7528,17 @@ def meta_scaled_grouped_mm(
out_dtype: Optional[torch.dtype] = None,
use_fast_accum: bool = False,
):
# Check dimensions
torch._check(
mat_a.dim() == 2 or mat_a.dim() == 3, lambda: "mat_a has to be 2 or 3d"
return _meta_grouped_mm_common(
mat_a,
mat_b,
scale_a=scale_a,
scale_b=scale_b,
offs=offs,
bias=bias,
scale_result=scale_result,
out_dtype=out_dtype,
use_fast_accum=use_fast_accum,
)
torch._check(
mat_b.dim() == 2 or mat_b.dim() == 3, lambda: "mat_b has to be 2 or 3d"
)

a_is_2d = mat_a.dim() == 2
b_is_2d = mat_b.dim() == 2

# Check offsets
torch._check(
(offs is not None) == (a_is_2d or b_is_2d),
lambda: "Have to provide offsets if there is a 2d matrix",
)

if offs is not None:
torch._check(offs.dim() == 1, lambda: "offs has to be 1D")
torch._check(offs.dtype == torch.int, lambda: "Offsets have to be int32")

# Check matrix sizes
torch._check(
mat_a.size(-1) % 16 == 0,
lambda: f"Expected trailing dimension of mat_a to be divisible by 16 but got mat1 shape: {mat_a.size()}",
)
torch._check(
mat_b.size(-2) % 16 == 0 and mat_b.size(-1) % 16 == 0,
lambda: f"Expected mat_b shape to be divisible by 16 but got mat_b shape: {mat_b.size()}",
)

# Check scales
torch._check(
scale_a.dtype == torch.float and scale_b.dtype == torch.float,
lambda: "Both scale_a and scale_b must be float (fp32) tensors.",
)

# Check scale dimensions
scale_multiplier = offs.size(0) if (a_is_2d and b_is_2d) else 1  # type: ignore[union-attr]

if a_is_2d:
torch._check(
scale_a.dim() == 1,
lambda: f"scale must be a 1D tensor for 2D mat_a, but got {scale_a.dim()}D",
)
torch._check(scale_a.is_contiguous(), lambda: "scale_a must be contiguous")
torch._check(
scale_a.size(0) == mat_a.size(0) * scale_multiplier,
lambda: "scale must have the same length as mat_a",
)
else:
torch._check(
scale_a.dim() == 2,
lambda: f"scale must be a 2D tensor for 3D mat_a, but got {scale_a.dim()}D",
)
torch._check(
scale_a.stride(1) == 1,
lambda: "scale_a must be contiguous in the last dimension",
)
torch._check(
scale_a.size(0) == mat_a.size(0),
lambda: "scale must have the same batch dimension as mat_a",
)
torch._check(
scale_a.size(1) == mat_a.size(1),
lambda: "scale must have the same first dimension as mat_a",
)

# Similar checks for scale_b
if b_is_2d:
torch._check(
scale_b.dim() == 1,
lambda: f"scale must be a 1D tensor for 2D mat_b, but got {scale_b.dim()}D",
)
torch._check(scale_b.is_contiguous(), lambda: "scale_b must be contiguous")
torch._check(
scale_b.size(0) == mat_b.size(1) * scale_multiplier,
lambda: "scale must have the same length as mat_b",
)
else:
torch._check(
scale_b.dim() == 2,
lambda: f"scale must be a 2D tensor for 3D mat_b, but got {scale_b.dim()}D",
)
torch._check(
scale_b.stride(1) == 1,
lambda: "scale_b must be contiguous in the last dimension",
)
torch._check(
scale_b.size(0) == mat_b.size(0),
lambda: "scale must have the same batch dimension as mat_b",
)
torch._check(
scale_b.size(1) == mat_b.size(2),
lambda: "scale must have the same last dimension as mat_b",
)

# Check bias
torch._check(bias is None, lambda: "Bias not supported yet")

# Check output dtype
out_dtype_ = out_dtype if out_dtype is not None else mat_a.dtype
torch._check(
out_dtype_ == torch.bfloat16,
lambda: "Only bf16 high precision output types are supported for grouped gemm",
)

# Compute output size
out_size = _compute_grouped_gemm_output_size(mat_a, mat_b, offs)
out = mat_a.new_empty(out_size, dtype=out_dtype)

return out

@register_meta(aten._softmax)
@@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
import functools
import hashlib
from typing import Optional

@functools.lru_cache(None)

@@ -50,9 +51,8 @@ def has_triton_tma_device():
):
# old API
try:
from triton.language.extra.cuda import (  # noqa: F401
experimental_device_tensormap_create1d,
experimental_device_tensormap_create2d,
from triton.language import (  # noqa: F401
_experimental_make_tensor_descriptor,
)

return True
@@ -122,3 +122,20 @@ def triton_hash_with_backend():

# Hash is upper case so that it can't contain any Python keywords.
return hashlib.sha256(key.encode("utf-8")).hexdigest().upper()

@functools.lru_cache(None)
def triton_set_allocator(device):
if has_triton_tma_device():
import torch

assert torch.cuda.current_device() == device

def alloc_fn(size: int, alignment: int, stream: Optional[int]):
return torch.empty(size, device=device, dtype=torch.int8)

import triton

triton.set_allocator(alloc_fn)

return None
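As a closing note on the semantics being tested: `offs` holds cumulative group boundaries along the varying dimension, and each group's result should match a plain `torch._scaled_mm` on the corresponding slices, which is essentially what the updated test helper checks. A hedged sketch of that per-group reference check for the 2d/3d case (function and variable names are illustrative, not from the commit; tolerances follow the relaxed atol=1e-1, rtol=1e-2 used in the tests):

import torch

def reference_check_2d_3d(a, b, scale_a, scale_b, offs, out):
    # a: (total_rows, K) fp8, b: (G, N, K) fp8, offs: (G,) int32 cumulative row ends, out: (total_rows, N) bf16.
    start = 0
    for g, end in enumerate(offs.tolist()):
        ref = torch._scaled_mm(
            a[start:end], b[g].t(),
            scale_a[start:end].view(-1, 1),  # per-row scales for this group
            scale_b[g].view(1, -1),          # per-column scales for this group
            out_dtype=torch.bfloat16,
        )
        torch.testing.assert_close(out[start:end], ref, atol=1e-1, rtol=1e-2)
        start = end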