pytorch/torch/_inductor/kernel/mm_grouped.py
Yavuz Yetim 7afcb030d8 Back out "Revert D81959389" (#163905)
Summary:
Original commit changeset: 06888d7ebff0

Original Phabricator Diff: D82932788

Restricted the test to SM90 for scaled_grouped_mm

Test Plan: TBD (will share the linux CI results)

Differential Revision: D83283991

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163905
Approved by: https://github.com/angelayi
2025-09-30 07:05:13 +00:00
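
The SM90 restriction mirrors the hardware gate in can_use_triton_kernel() further down in this file. A minimal sketch of what such a test-side guard can look like (the test class and method names are placeholders, not the actual test touched by this diff):

import unittest

import torch

IS_SM90 = (
    torch.cuda.is_available()
    and torch.cuda.get_device_capability() == (9, 0)
    and not torch.version.hip
)


class TestScaledGroupedMM(unittest.TestCase):
    @unittest.skipIf(not IS_SM90, "scaled_grouped_mm Triton lowering requires SM90")
    def test_scaled_grouped_mm(self):
        ...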


# mypy: allow-untyped-defs
import logging
from dataclasses import dataclass
from typing import Any, Optional
import torch
from torch._dynamo.utils import counters
from torch._inductor.runtime.triton_compat import tl
from torch._inductor.virtualized import V
from torch.utils._triton import has_triton
from ..ir import ChoiceCaller, Layout, TensorBox
from ..lowering import register_lowering
from ..select_algorithm import (
autotune_select_algorithm,
ExternKernelChoice,
realize_inputs,
TritonTemplate,
)
from ..utils import (
get_gpu_shared_memory,
get_num_sms,
has_free_symbols,
use_aten_gemm_kernels,
use_triton_template,
)
from .mm_common import (
_is_static_problem,
check_supported_striding,
persistent_grouped_mm_grid,
)

log = logging.getLogger(__name__)
aten = torch.ops.aten


@dataclass
class Config:
kwargs: dict[str, int]
num_stages: int
num_warps: int


_NV_CONFIGS = [
Config(
{
"BLOCK_M": block_size_m,
"BLOCK_N": block_size_n,
"BLOCK_K": block_size_k,
"NUM_CONSUMER_GROUPS": 1,
},
num_stages=num_stages,
num_warps=num_warps,
)
for block_size_m in [16, 32, 64, 128]
for block_size_n in [64, 128, 256]
for block_size_k in [64, 128, 256]
for num_stages in [3, 4]
for num_warps in [4, 8]
]


def grouped_mm_configs():
return _NV_CONFIGS


def early_config_prune(g, m, configs, named_args):
dtsize = 1
pruned_configs = []
for config in configs:
kw = config.kwargs
BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps, num_consumer_groups = (
kw["BLOCK_M"],
kw["BLOCK_N"],
kw["BLOCK_K"],
config.num_stages,
config.num_warps,
getattr(config, "num_consumer_groups", 0),
)
# 1. Prune NV configs depending on g and m.
if not has_free_symbols((g, m)):
a_is_2d, b_is_2d = named_args["A_IS_2D"], named_args["B_IS_2D"]
m_avg = m // g if a_is_2d and not b_is_2d else m
if m_avg <= 16:
if BLOCK_M > 32:
continue
elif m_avg <= 32:
if BLOCK_M > 64:
continue
elif m_avg <= 64:
if BLOCK_M <= 16:
continue
else:
if BLOCK_M <= 32:
continue
# 2. make sure we have enough smem
max_shared_memory = get_gpu_shared_memory()
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
if required_shared_memory > max_shared_memory:
continue
use_warp_specialization = num_consumer_groups >= 1
# 3. make sure we can partition for ws
if use_warp_specialization:
if num_warps != 4:
continue
# "tritongpu-warp-spec-data-partition"
m_slice = BLOCK_M // num_consumer_groups
n_slice = BLOCK_N // num_consumer_groups
if m_slice < 64 and n_slice < 256:
continue
pruned_configs.append(config)
return pruned_configs
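

# Worked example of the pruning above (numbers are illustrative, not taken from
# a profile): the comprehension over _NV_CONFIGS yields 4 * 3 * 3 * 2 * 2 = 144
# candidate configs. Step 2 then drops configs whose software-pipelined tiles
# cannot fit in shared memory: with dtsize = 1 as hardcoded above, BLOCK_M=128,
# BLOCK_N=256, BLOCK_K=256 and num_stages=4 need (128 + 256) * 256 * 4 * 1 =
# 393216 bytes, more than the ~228 KiB of shared memory an H100 SM offers, so
# that config is pruned before autotuning.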


triton_grouped_mm_source = r"""
{%- if SCALED %}
{%- if A_IS_2D or B_IS_2D %}
{{def_kernel("a_ptr", "b_ptr", "scale_a_ptr", "scale_b_ptr", "offsets_ptr")}}
{%- else %}
{{def_kernel("a_ptr", "b_ptr", "scale_a_ptr", "scale_b_ptr")}}
{%- endif %}
{%- else %}
{%- if A_IS_2D or B_IS_2D %}
{{def_kernel("a_ptr", "b_ptr", "offsets_ptr")}}
{%- else %}
{{def_kernel("a_ptr", "b_ptr")}}
{%- endif %}
{%- endif %}
tidx = tl.program_id(0).to(INDEX_DTYPE)
{%- set M_IS_VARYING = A_IS_2D and not B_IS_2D %}
{%- set N_IS_VARYING = not A_IS_2D and B_IS_2D %}
{%- set K_IS_VARYING = A_IS_2D and B_IS_2D %}
{%- if A_IS_2D %}
{%- if B_IS_2D %}
G = {{size("offsets_ptr", 0)}}
{%- else %}
G = {{size("b_ptr", 0)}}
{%- endif %}
{%- else %}
{%- if B_IS_2D %}
G = {{size("a_ptr", 0)}}
{%- else %}
G = {{size("a_ptr", 0)}}
{%- endif %}
{%- endif %}
# the b_ptr tensor is given with its last two dims transposed, revert here
M = {{size("a_ptr", -2)}}
N = {{size("b_ptr", -1)}}
K = {{size("a_ptr", -1)}}
A_STRIDE_M = {{stride("a_ptr", -2)}}
A_STRIDE_K = {{stride("a_ptr", -1)}}
{%- if not A_IS_2D %}
A_STRIDE_G = {{stride("a_ptr", 0)}}
{%- if SCALED %}
SCALE_A_STRIDE_G = {{stride("scale_a_ptr", 0)}}
{%- endif %}
{%- endif %}
B_STRIDE_N = {{stride("b_ptr", -1)}}
B_STRIDE_K = {{stride("b_ptr", -2)}}
{%- if not B_IS_2D %}
B_STRIDE_G = {{stride("b_ptr", 0)}}
{%- if SCALED %}
SCALE_B_STRIDE_G = {{stride("scale_b_ptr", 0)}}
{%- endif %}
{%- endif %}
{%- if USE_TMA_LOAD %}
{%- if USE_EXPERIMENTAL_MAKE_TENSOR_DESCRIPTOR %}
a_desc = tl._experimental_make_tensor_descriptor(
{%- else %}
a_desc = tl.make_tensor_descriptor(
{%- endif %}
a_ptr,
{%- if A_IS_2D %}
shape=[M, K],
# fixme: strides=[A_STRIDE_M, A_STRIDE_K],
strides=[{{stride("a_ptr", -2)}}, {{stride("a_ptr", -1)}}],
block_shape=[BLOCK_M, BLOCK_K],
{%- else %}
shape=[G, M, K],
# fixme: strides=[A_STRIDE_G, A_STRIDE_M, A_STRIDE_K],
strides=[{{stride("a_ptr", 0)}}, {{stride("a_ptr", -2)}}, {{stride("a_ptr", -1)}}],
block_shape=[1, BLOCK_M, BLOCK_K],
{%- endif %}
)
{%- if USE_EXPERIMENTAL_MAKE_TENSOR_DESCRIPTOR %}
b_desc = tl._experimental_make_tensor_descriptor(
{%- else %}
b_desc = tl.make_tensor_descriptor(
{%- endif %}
b_ptr,
{%- if B_IS_2D %}
shape=[N, K],
# fixme: strides=[B_STRIDE_N, B_STRIDE_K],
strides=[{{stride("b_ptr", -1)}}, {{stride("b_ptr", -2)}}],
block_shape=[BLOCK_N, BLOCK_K],
{%- else %}
shape=[G, N, K],
# fixme: strides=[B_STRIDE_G, B_STRIDE_N, B_STRIDE_K],
strides=[{{stride("b_ptr", 0)}}, {{stride("b_ptr", -1)}}, {{stride("b_ptr", -2)}}],
block_shape=[1, BLOCK_N, BLOCK_K],
{%- endif %}
)
{%- endif %}
{%- if M_IS_VARYING %}
m_end_offset = 0
{%- endif %}
{%- if N_IS_VARYING %}
n_end_offset = 0
{%- endif %}
{%- if K_IS_VARYING %}
k_end_offset = 0
{%- endif %}
iterated_tiles = 0
for g in tl.range(G):
{%- if M_IS_VARYING %}
# Move across groups
m_start_offset = m_end_offset
m_end_offset = tl.load(offsets_ptr + g)
m_size = m_end_offset - m_start_offset
{%- if SCALED %}
m_scale_start_offset = m_start_offset
{%- endif %}
{%- else %}
m_start_offset = 0
m_size = M
{%- if SCALED %}
m_scale_start_offset = g * M
{%- endif %}
{%- endif %}
{%- if N_IS_VARYING %}
# Move across groups
n_start_offset = n_end_offset
n_end_offset = tl.load(offsets_ptr + g)
n_size = n_end_offset - n_start_offset
{%- if SCALED %}
n_scale_start_offset = n_start_offset
{%- endif %}
{%- else %}
n_start_offset = 0
n_size = N
{%- if SCALED %}
n_scale_start_offset = g * N
{%- endif %}
{%- endif %}
if m_size > 0 and n_size > 0:
{%- if K_IS_VARYING %}
# Move across groups
k_start_offset = k_end_offset
k_end_offset = tl.load(offsets_ptr + g)
k_size = k_end_offset - k_start_offset
{%- else %}
k_start_offset = 0
k_size = K
{%- endif %}
num_m_tiles = tl.cdiv(m_size, BLOCK_M)
num_n_tiles = tl.cdiv(n_size, BLOCK_N)
num_tiles = num_m_tiles * num_n_tiles
# Move across tiles
while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
gidx = tidx - iterated_tiles
# Split M first and N second.
tile_m_idx = gidx % num_m_tiles
tile_n_idx = gidx // num_m_tiles
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
{%- if USE_TMA_LOAD %}
m_offset = (m_start_offset + tile_m_idx * BLOCK_M).to(tl.int32)
n_offset = (n_start_offset + tile_n_idx * BLOCK_N).to(tl.int32)
for k_offset in range(0, k_size, BLOCK_K):
{%- if A_IS_2D %}
a = a_desc.load([m_offset, k_start_offset + k_offset])
{%- else %}
a = a_desc.load([g, m_offset, k_start_offset + k_offset]).reshape(BLOCK_M, BLOCK_K)
{%- endif %}
{%- if B_IS_2D %}
b = b_desc.load([n_offset, k_start_offset + k_offset])
{%- else %}
b = b_desc.load([g, n_offset, k_start_offset + k_offset]).reshape(BLOCK_N, BLOCK_K)
{%- endif %}
{%- if K_IS_VARYING %}
if k_offset + BLOCK_K > k_size:
group_offs_k = k_offset + tl.arange(0, BLOCK_K)
a = tl.where(group_offs_k < k_size, a, 0)
b = tl.where(group_offs_k < k_size, b, 0)
{%- endif %}
{%- if USE_FAST_ACCUM %}
accumulator = tl.dot(a, b.T, accumulator)
{%- else %}
accumulator += tl.dot(a, b.T)
{%- endif %}
{%- else %}
offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
for k_offset in range(0, k_size, BLOCK_K):
group_offs_k = k_offset + tl.arange(0, BLOCK_K)
offs_k = group_offs_k + k_start_offset
a_ptrs = (
a_ptr
{%- if not A_IS_2D %}
+ g * A_STRIDE_G
{%- endif %}
+ (m_start_offset + offs_am[:, None]) * A_STRIDE_M
+ offs_k[None, :] * A_STRIDE_K
)
b_ptrs = (
b_ptr
{%- if not B_IS_2D %}
+ g * B_STRIDE_G
{%- endif %}
+ (n_start_offset + offs_bn[:, None]) * B_STRIDE_N
+ offs_k[None, :] * B_STRIDE_K
)
a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)
if k_offset + BLOCK_K > k_size:
a = tl.where(group_offs_k < k_size, a, 0)
b = tl.where(group_offs_k < k_size, b, 0)
{%- if USE_FAST_ACCUM %}
accumulator = tl.dot(a, b.T, accumulator)
{%- else %}
accumulator += tl.dot(a, b.T)
{%- endif %}
a_ptrs += BLOCK_K
b_ptrs += BLOCK_K
{%- endif %}
offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
{%- if SCALED %}
scale_a = tl.load(
scale_a_ptr
{%- if A_IS_2D %}
+ m_scale_start_offset
{%- else %}
+ g * SCALE_A_STRIDE_G
{%- endif %}
+ offs_am[:, None],
mask=offs_am[:, None] < m_size,
)
scale_b = tl.load(
scale_b_ptr
{%- if B_IS_2D %}
+ n_scale_start_offset
{%- else %}
+ g * SCALE_B_STRIDE_G
{%- endif %}
+ offs_bn[None, :],
mask=offs_bn[None, :] < n_size,
)
c = accumulator.to(tl.float32) * scale_a * scale_b
{%- else %}
c = accumulator.to(tl.float32)
{%- endif %}
{%- if M_IS_VARYING %}
idx_m = (m_start_offset + offs_am[:, None])
{%- else %}
idx_m = offs_am[:, None]
{%- endif %}
{%- if N_IS_VARYING %}
idx_n = (n_start_offset + offs_bn[None, :])
{%- else %}
idx_n = offs_bn[None, :]
{%- endif %}
mask = (offs_am[:, None] < m_size) & (offs_bn[None, :] < n_size)
{%- if M_IS_VARYING or N_IS_VARYING %}
{{store_output(("idx_m", "idx_n"), "c", "mask", indent_width=16, val_shape=("BLOCK_M", "BLOCK_N"))}}
{%- else %}
{{store_output(("g", "idx_m", "idx_n"), "c", "mask", indent_width=16, val_shape=("BLOCK_M", "BLOCK_N"))}}
{%- endif %}
tidx += NUM_SMS
iterated_tiles += num_tiles
"""

triton_grouped_mm_template = TritonTemplate(
name="grouped_mm",
grid=persistent_grouped_mm_grid,
source=triton_grouped_mm_source,
)

triton_scaled_grouped_mm_template = TritonTemplate(
name="scaled_grouped_mm",
grid=persistent_grouped_mm_grid,
source=triton_grouped_mm_source,
)
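

# The template above is persistent: each Triton program id sweeps the global
# tile space in strides of NUM_SMS (the `tidx += NUM_SMS` line in the kernel
# source). A host-side sketch of that assignment, handy for reasoning about
# load balance, is given below; group_tile_counts, num_sms and pid are made-up
# inputs, and nothing in this file calls it.
def _sketch_persistent_tile_assignment(group_tile_counts, num_sms, pid):
    """Return the (group, tile) pairs a given program id would process."""
    tiles = []
    tidx = pid
    iterated_tiles = 0
    for g, num_tiles in enumerate(group_tile_counts):
        # Mirrors the `while tidx >= iterated_tiles and ...` loop in the kernel.
        while iterated_tiles <= tidx < iterated_tiles + num_tiles:
            tiles.append((g, tidx - iterated_tiles))
            tidx += num_sms
        iterated_tiles += num_tiles
    return tiles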


def grouped_mm_args(
mat1: TensorBox,
mat2: TensorBox,
offs: Optional[TensorBox],
layout=None,
out_dtype=None,
):
mat1, mat2 = realize_inputs(mat1, mat2)
if offs is not None:
realize_inputs(offs)
mat1_size = mat1.get_size()
mat2_size = mat2.get_size()
m1dim, m2dim = len(mat1_size), len(mat2_size)
assert m1dim == 2 or m1dim == 3
assert m2dim == 2 or m2dim == 3
if layout is None:
from torch._inductor.ir import FixedLayout
if out_dtype is None:
out_dtype = mat1.get_dtype()
alignment = 16 // out_dtype.itemsize
if m1dim == 2:
if m2dim == 2:
assert offs is not None
out_size = [offs.get_size()[0], mat1_size[0], mat2_size[1]]
else:
out_size = [mat1_size[0], mat2_size[-1]]
else:
if m2dim == 2:
out_size = [mat1_size[1], mat2_size[1]]
else:
out_size = [mat1_size[0], mat1_size[1], mat2_size[-1]]
size_padded = (out_size[-1] + alignment - 1) // alignment * alignment
if len(out_size) == 2:
out_stride = [size_padded, 1]
else:
out_stride = [out_size[1] * size_padded, size_padded, 1]
layout = FixedLayout(
mat1.get_device(),
out_dtype,
out_size,
out_stride,
)
else:
assert out_dtype is None, "out_dtype is ignored if layout is specified."
return (mat1_size, mat2_size, layout, mat1, mat2, offs)
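

# Concrete shapes for the bookkeeping above (illustrative values): in the common
# MoE-style case mat1 is 2D with shape (M_total, K) = (1024, 256), mat2 is 3D
# with shape (G, K, N) = (8, 256, 512), and offs holds 8 cumulative row counts,
# giving out_size = [1024, 512]. out_stride pads the last dimension up to a
# multiple of 16 // itemsize elements (8 for bfloat16) so rows stay 16-byte
# aligned.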


aten__grouped_mm = ExternKernelChoice(
torch._grouped_mm,
"at::_grouped_mm",
op_overload=aten._grouped_mm.default,
has_out_variant=False,
)

aten__scaled_grouped_mm = ExternKernelChoice(
torch._scaled_grouped_mm,
"at::_scaled_grouped_mm",
op_overload=aten._scaled_grouped_mm.default,
has_out_variant=False,
)


def can_use_triton_kernel(
mat_a: TensorBox,
mat_b: TensorBox,
offs: Optional[TensorBox],
bias: Optional[TensorBox],
scale_result: Optional[TensorBox],
) -> bool:
if not (
torch.cuda.is_available()
and torch.cuda.get_device_capability() == (9, 0)
and not torch.version.hip
):
return False
if not has_triton():
return False
# The _grouped_mm()/_scaled_grouped_mm() operator do not support
# bias nor scale_result yet.
if bias is not None:
return False
if scale_result is not None:
return False
if len(mat_a.get_size()) == 2 or len(mat_b.get_size()) == 2:
return offs is not None
else:
return offs is None
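

# Example of the dimensionality rule above (shapes are illustrative): a 2D
# mat_a of shape (1024, 256) paired with a 3D mat_b of shape (8, 256, 512) is
# eligible for the Triton template only when offs is provided, whereas a pure
# 3D x 3D call, e.g. (8, 64, 256) x (8, 256, 512), is eligible only when offs
# is None.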


def create_offsets(x, m1_size, m2_size, offs_size):
m1_is_2d = len(m1_size) == 2
m2_is_2d = len(m2_size) == 2
if m1_is_2d:
if m2_is_2d:
k = V.graph.sizevars.size_hint(m1_size[1])
noffs = V.graph.sizevars.size_hint(offs_size[0])
step = k / noffs
return torch.linspace(
step, k, noffs, dtype=x.get_dtype(), device=x.get_device()
)
else:
m = V.graph.sizevars.size_hint(m1_size[0])
noffs = V.graph.sizevars.size_hint(offs_size[0])
step = m / noffs
return torch.linspace(
step, m, noffs, dtype=x.get_dtype(), device=x.get_device()
)
else:
if m2_is_2d:
n = V.graph.sizevars.size_hint(m2_size[0])
noffs = V.graph.sizevars.size_hint(offs_size[0])
step = n / noffs
return torch.linspace(
step, n, noffs, dtype=x.get_dtype(), device=x.get_device()
)
else:
return None
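

# Example of the autotuning offsets generated above (size hints are
# illustrative): for a 2D mat_a with m = 1024 rows split into noffs = 8 groups,
# step = 128.0 and torch.linspace(128, 1024, 8) yields the cumulative offsets
# [128, 256, 384, 512, 640, 768, 896, 1024], i.e. equally sized groups for the
# benchmarking runs.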


def _tuned_grouped_mm_common(
operator_name: str,
algorithm_name: str,
extern_kernel_choice: ExternKernelChoice,
kernel_template: TritonTemplate,
mat_a: TensorBox,
mat_b: TensorBox,
scale_a: Optional[TensorBox] = None,
scale_b: Optional[TensorBox] = None,
offs: Optional[TensorBox] = None,
bias: Optional[TensorBox] = None,
scale_result: Optional[TensorBox] = None,
out_dtype: Optional[torch.dtype] = None,
use_fast_accum: Optional[bool] = None,
layout: Optional[Layout] = None,
) -> TensorBox:
assert (scale_a is None) == (scale_b is None)
assert scale_result is None or scale_a is not None
m1_size, m2_size, layout, mat_a, mat_b, offs = grouped_mm_args(
mat_a, mat_b, offs, layout=layout, out_dtype=out_dtype
)
counters["aten_mm_info"][operator_name] += 1
log_message = f"Tuned {operator_name}: mat1_shape=%s, mat2_shape=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s"
log.info(
log_message,
m1_size,
m2_size,
mat_a.get_dtype(),
mat_b.get_dtype(),
layout,
)
if scale_a is not None and scale_b is not None:
check_supported_striding(mat_a, mat_b)
# workaround for Inductor not supporting optional tensor input arguments
input_nodes: list[Any] = [mat_a, mat_b]
if scale_a is not None:
input_nodes.append(realize_inputs(scale_a))
if scale_b is not None:
input_nodes.append(realize_inputs(scale_b))
if offs is not None:
input_nodes.append(realize_inputs(offs))
if use_fast_accum is None:
aten_choice = extern_kernel_choice.bind(
input_nodes,
layout,
out_dtype=out_dtype,
)
else:
aten_choice = extern_kernel_choice.bind(
input_nodes,
layout,
out_dtype=out_dtype,
use_fast_accum=use_fast_accum,
)
if use_fast_accum is None:
use_fast_accum = False
choices: list[ChoiceCaller] = []
if use_aten_gemm_kernels():
choices.append(aten_choice)
_, is_nonzero = _is_static_problem(layout)
# Checking only for the equality of corresponding dims of
# multiplicands here, relying on meta function checks for
# everything else.
if (
is_nonzero
and use_triton_template(layout)
and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result)
):
scaled = scale_a is not None
if len(m1_size) == 2:
if len(m2_size) == 2:
m, k1 = m1_size
k2, _ = m2_size
g = offs.get_size()[0]
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, True
else:
g1 = offs.layout.size[0]
m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, False
else:
if len(m2_size) == 2:
g1 = offs.layout.size[0]
g2, m, k1 = m1_size
k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, True
else:
g1, m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, False
triton_has_make_tensor_descriptor = hasattr(tl, "make_tensor_descriptor")
triton_has_experimental_make_tensor_descriptor = hasattr(
tl, "_experimental_make_tensor_descriptor"
)
use_tma_load = (
triton_has_make_tensor_descriptor
or triton_has_experimental_make_tensor_descriptor
)
# The make_tensor_descriptor imposes this additional limitation.
use_tma_load = use_tma_load and (
mat_a.get_stride()[-1] == 1 and mat_b.get_stride()[-2] == 1
)
kwargs = {
"SCALED": scaled,
"A_IS_2D": a_is_2d,
"B_IS_2D": b_is_2d,
"USE_FAST_ACCUM": use_fast_accum,
"NUM_SMS": get_num_sms(),
"USE_TMA_LOAD": use_tma_load,
"USE_EXPERIMENTAL_MAKE_TENSOR_DESCRIPTOR": triton_has_experimental_make_tensor_descriptor,
}
for config in early_config_prune(g, m, grouped_mm_configs(), kwargs):
kernel_template.maybe_append_choice(
choices,
input_nodes=input_nodes,
layout=layout,
num_stages=config.num_stages,
num_warps=config.num_warps,
**kwargs,
**config.kwargs,
)
input_gen_fns = {
4: lambda x: create_offsets(
x, m1_size, m2_size, offs.get_size() if offs is not None else None
),
}
return autotune_select_algorithm(
algorithm_name, choices, input_nodes, layout, input_gen_fns=input_gen_fns
)


@register_lowering(aten._grouped_mm.default, type_promotion_kind=None)
def tuned_grouped_mm(
mat_a: TensorBox,
mat_b: TensorBox,
offs: Optional[TensorBox] = None,
bias: Optional[TensorBox] = None,
out_dtype: Optional[torch.dtype] = None,
layout: Optional[Layout] = None,
) -> TensorBox:
"""Auto-tuning for _grouped_mm() operator."""
return _tuned_grouped_mm_common(
"aten._grouped_mm.default",
"grouped_mm",
aten__grouped_mm,
triton_grouped_mm_template,
mat_a,
mat_b,
None,
None,
offs,
bias,
None,
out_dtype,
None,
layout,
)


@register_lowering(aten._scaled_grouped_mm.default, type_promotion_kind=None)
def tuned_scaled_grouped_mm(
mat_a: TensorBox,
mat_b: TensorBox,
scale_a: TensorBox,
scale_b: TensorBox,
offs: Optional[TensorBox] = None,
bias: Optional[TensorBox] = None,
scale_result: Optional[TensorBox] = None,
out_dtype: Optional[torch.dtype] = None,
use_fast_accum: bool = False,
layout: Optional[Layout] = None,
) -> TensorBox:
"""Auto-tuning for _scaled_grouped_mm() operator."""
# matching _scaled_grouped_mm_cuda Blas.cpp implementation
out_dtype = out_dtype or torch.bfloat16
return _tuned_grouped_mm_common(
"aten._scaled_grouped_mm.default",
"scaled_grouped_mm",
aten__scaled_grouped_mm,
triton_scaled_grouped_mm_template,
mat_a,
mat_b,
scale_a,
scale_b,
offs,
bias,
scale_result,
out_dtype,
use_fast_accum,
layout,
)
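

# A minimal end-to-end sketch of how this lowering gets exercised (the shapes,
# dtypes and the assumption that torch._grouped_mm accepts (mat_a, mat_b, offs)
# positionally are illustrative, inferred from the lowering signature above
# rather than taken from this diff):
def _example_grouped_mm(a, b, offs):
    return torch._grouped_mm(a, b, offs)


if __name__ == "__main__":
    if torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0):
        a = torch.randn(1024, 256, device="cuda", dtype=torch.bfloat16)
        # b is created with its last two dims transposed, the layout the kernel
        # comment above ("given with its last two dims transposed") refers to.
        b = torch.randn(8, 512, 256, device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
        offs = torch.arange(128, 1025, 128, device="cuda", dtype=torch.int32)
        compiled = torch.compile(_example_grouped_mm, mode="max-autotune")
        # With max-autotune, tuned_grouped_mm() above can pick between the
        # aten._grouped_mm extern kernel and the Triton grouped_mm template.
        out = compiled(a, b, offs)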