[Inductor][Triton][FP8] Refactor scaled_mm template to accept scaling mode (#164318)

Summary: Refactor the `scaled_mm` Inductor template so that the template choice is driven by the scaling mode. This sets up the infrastructure for adding templates for new scaling modes, such as deepseek-style scaling (a follow-up diff): the newer scaling modes (deepseek, block, group) apply scaling before the accumulation, whereas per-tensor and per-row scaling apply it after the accumulation, in the epilogue (hence the template rename to `scaled_mm_device_tma_epilogue_scaling`). This change also enables Inductor to infer the scaling type from the shapes of the scale tensors, which makes the existing infrastructure more extensible to new scaling modes.
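
As a hedged illustration of the shape-based inference (shapes match the `--m 256 --n 768 --k 512` sizes in the test plan below; `torch._scaled_mm` is the op this lowering handles), the two recipes supported by the refactored template look like this in eager mode:

```python
import torch

M, K, N = 256, 512, 768
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()  # mat_b column-major

# Tensor-wise (ScalingType.TensorWise, recipe 0): scalar (0-d) scales.
y_tw = torch._scaled_mm(
    a, b,
    scale_a=torch.tensor(1.0, device="cuda"),
    scale_b=torch.tensor(1.0, device="cuda"),
    out_dtype=torch.bfloat16,
)

# Row-wise (ScalingType.RowWise, recipe 1): an (M, 1) scale for mat_a and a
# (1, N) scale for mat_b; Inductor infers the recipe from these shapes.
y_rw = torch._scaled_mm(
    a, b,
    scale_a=torch.ones(M, 1, device="cuda"),
    scale_b=torch.ones(1, N, device="cuda"),
    out_dtype=torch.bfloat16,
)
```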

Test Plan:
```
TORCHINDUCTOR_CACHE_DIR=~/personal/cache_dir_inductor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 ENABLE_PERSISTENT_TMA_MATMUL=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 buck2 run mode/{opt,inplace} pytorch/tritonbench:run -- --op fp8_gemm --only torch_fp8_gemm,pt2_fp8_gemm --metrics tflops,accuracy --m 256 --n 768 --k 512 --output="/home/jananisriram/personal/random_bench.csv" --scaling_rowwise --atol=20 --rtol=2 2>&1 | tee ~/personal/random.log
```

Differential Revision: D83591083

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164318
Approved by: https://github.com/drisspg, https://github.com/slayton58
Author: Janani Sriram
Date: 2025-10-16 20:40:45 +00:00
Committed by: PyTorch MergeBot
Commit: 9bf5b38c14 (parent: aba8c43594)

3 changed files with 88 additions and 31 deletions


```diff
@@ -623,7 +623,8 @@ class TestFP8Lowering(TestCase):
             bias,
         )
-        FileCheck().check("SCALING_ROWWISE : tl.constexpr = False").run(code[0])
+        FileCheck().check("SCALE_RECIPE_A : tl.constexpr = 0").run(code[0])
+        FileCheck().check("SCALE_RECIPE_B : tl.constexpr = 0").run(code[0])
         self.assertEqual(y_eager.dtype, dtype)
         self.assertEqual(y_compiled.dtype, dtype)
         # depending on the kernel config (BLOCK_M size, etc) selected during Inductor
@@ -768,7 +769,8 @@ class TestFP8Lowering(TestCase):
             bias,
         )
-        FileCheck().check("SCALING_ROWWISE : tl.constexpr = True").run(code[0])
+        FileCheck().check("SCALE_RECIPE_A : tl.constexpr = 1").run(code[0])
+        FileCheck().check("SCALE_RECIPE_B : tl.constexpr = 1").run(code[0])
         self.assertEqual(y_eager.dtype, dtype)
         self.assertEqual(y_compiled.dtype, dtype)
         torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)
```


```diff
@@ -16,6 +16,7 @@ from torch._inductor.codegen.cpp_gemm_template import CppGemmTemplate
 from torch._inductor.remote_gemm_autotune_cache import gen_best_config
 from torch._inductor.virtualized import ops, V
 from torch.fx.experimental.proxy_tensor import make_fx
+from torch.nn.functional import ScalingType  # type: ignore[attr-defined]
 from torch.torch_version import TorchVersion

 from .. import config as inductor_config
@@ -372,15 +373,11 @@ persistent_tma_mm_template = TritonTemplate(

 load_scales = r"""
 @triton.jit
-def load_scales(a_scale_ptr, b_scale_ptr, SCALING_ROWWISE: tl.constexpr):
-    if SCALING_ROWWISE:
-        # For row-wise scaling, we'll return the pointers
-        return a_scale_ptr, b_scale_ptr
+def load_scales(scale_ptr, SCALE_RECIPE: tl.constexpr):
+    if SCALE_RECIPE == 0:
+        return tl.load(scale_ptr)  # For tensor-wise scaling, we'll load the scalar values
     else:
-        # For per-tensor scaling, we'll load the scalar values
-        a_scale = tl.load(a_scale_ptr)
-        b_scale = tl.load(b_scale_ptr)
-        return a_scale, b_scale
+        return scale_ptr  # For all other scaling recipes, we'll return the pointers
 """
@@ -390,7 +387,8 @@ def apply_scaling(
     accumulator,
     a_scale,
     b_scale,
-    SCALING_ROWWISE: tl.constexpr,
+    SCALE_RECIPE_A: tl.constexpr,
+    SCALE_RECIPE_B: tl.constexpr,
     offs_cm,
     offs_cn,
     M,
@@ -398,7 +396,7 @@ def apply_scaling(
     stride_a_scale_m,
     stride_b_scale_n,
 ):
-    if SCALING_ROWWISE:
+    if SCALE_RECIPE_A == 1 and SCALE_RECIPE_B == 1:  # (ScalingType.RowWise, ScalingType.RowWise)
         # For row-wise scaling, we need to load the scales for each row/column
         a_scales = tl.load(
             a_scale + (offs_cm * stride_a_scale_m),
@@ -411,7 +409,7 @@ def apply_scaling(
             other=0.0,
         )
         acc_scale = a_scales[:, None] * b_scales[None, :]
-    else:
+    else:  # (ScalingType.TensorWise, ScalingType.TensorWise)
         # For per-tensor scaling, we can directly use the loaded scalar values
         acc_scale = a_scale * b_scale
@@ -419,7 +417,7 @@ def apply_scaling(
 """

-device_tma = r"""
+scaled_mm_device_tma_epilogue_scaling = r"""
     {{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}}
     M = {{size("A", 0)}}
     N = {{size("B", 1)}}
@@ -433,11 +431,14 @@ device_tma = r"""
     stride_bk = {{stride("B", 0)}}
    stride_bn = {{stride("B", 1)}}

-    if SCALING_ROWWISE:
+    if SCALE_RECIPE_A == 1:  # ScalingType.RowWise
         stride_a_scale_m = 1
-        stride_b_scale_n = 1
     else:
         stride_a_scale_m = 0
+
+    if SCALE_RECIPE_B == 1:  # ScalingType.RowWise
+        stride_b_scale_n = 1
+    else:
         stride_b_scale_n = 0

     start_pid = tl.program_id(axis=0).to(INDEX_DTYPE)
@@ -500,7 +501,8 @@ device_tma = r"""
     num_pid_in_group = GROUP_M * num_pid_n

     accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
-    a_scale, b_scale = load_scales(A_inverse_scale, B_inverse_scale, SCALING_ROWWISE)
+    a_scale = load_scales(A_inverse_scale, SCALE_RECIPE_A)
+    b_scale = load_scales(B_inverse_scale, SCALE_RECIPE_B)

     for _ in range(0, k_tiles * tiles_per_SM):
         ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
@@ -542,7 +544,8 @@ device_tma = r"""
                 accumulator,
                 a_scale,
                 b_scale,
-                SCALING_ROWWISE,
+                SCALE_RECIPE_A,
+                SCALE_RECIPE_B,
                 offs_cm,
                 offs_cn,
                 M,
@@ -570,10 +573,10 @@ device_tma = r"""
 """

-scaled_mm_device_tma_template = TritonTemplate(
-    name="scaled_mm_device_tma",
+scaled_mm_device_tma_epilogue_scaling_template = TritonTemplate(
+    name="scaled_mm_device_tma_epilogue_scaling",
     grid=persistent_mm_grid,
-    source=device_tma + load_scales + apply_scaling,
+    source=scaled_mm_device_tma_epilogue_scaling + load_scales + apply_scaling,
 )

 _compute_blackwell_pid = r"""
```
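
The integer recipe values the kernel specializes on come from `ScalingType.value`, handed to the template as `tl.constexpr` overrides in `tuned_scaled_mm` below. A stand-in sketch of the assumed mapping (the real enum is `torch.nn.functional.ScalingType`):

```python
from enum import Enum

# Stand-in for torch.nn.functional.ScalingType; only the two members used by
# this template are shown, with the values the updated tests FileCheck for.
class ScalingType(Enum):
    TensorWise = 0  # SCALE_RECIPE == 0: load_scales loads the scalar scale
    RowWise = 1     # SCALE_RECIPE == 1: load_scales returns the pointer and
                    # apply_scaling indexes it per output row/column
```
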
```diff
@@ -1319,6 +1322,38 @@ def tuned_sparse_semi_structured_mm(
     )


+scaling_pairs = [
+    (ScalingType.TensorWise, ScalingType.TensorWise),
+    (ScalingType.RowWise, ScalingType.RowWise),
+]
+
+
+def _is_tensorwise_scaling(sz: Any) -> bool:
+    return (len(sz) == 0) or all(
+        V.graph.sizevars.statically_known_equals(d, 1) for d in sz
+    )
+
+
+def _is_rowwise_scaling(sz: Any, transpose: bool) -> bool:
+    idx = 0 if transpose else -1
+    return V.graph.sizevars.statically_known_equals(sz[idx], 1)
+
+
+def is_desired_scaling(
+    t: torch.Tensor,
+    scale_size: torch.Tensor,
+    scaling_type: ScalingType,
+    transpose: bool = False,
+) -> bool:
+    match scaling_type:
+        case ScalingType.TensorWise:
+            return _is_tensorwise_scaling(scale_size)
+        case ScalingType.RowWise:
+            return _is_rowwise_scaling(scale_size, transpose)
+        case _:
+            raise AssertionError(f"Unsupported scaling type {scaling_type}")
+
+
 @register_lowering(aten._scaled_mm.default, type_promotion_kind=None)  # type: ignore[misc]
 def tuned_scaled_mm(
     mat_a,
```
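
Stripped of Inductor's symbolic sizevars, the classification above reduces to a plain shape check. A minimal sketch with concrete ints in place of `V.graph.sizevars` (helper names here are illustrative):

```python
def is_tensorwise(scale_shape: tuple[int, ...]) -> bool:
    # A scalar (0-d) scale, or one whose dims are all 1, is tensor-wise.
    return len(scale_shape) == 0 or all(d == 1 for d in scale_shape)

def is_rowwise(scale_shape: tuple[int, ...], transpose: bool = False) -> bool:
    # mat_a scales per row, e.g. (M, 1); mat_b is checked transposed, e.g. (1, N).
    idx = 0 if transpose else -1
    return scale_shape[idx] == 1

assert is_tensorwise(())                     # scalar scale -> TensorWise
assert is_rowwise((256, 1))                  # per-row scale for mat_a
assert is_rowwise((1, 768), transpose=True)  # per-column scale for mat_b
```

Note that a `(1, 1)` scale satisfies both predicates; because `scaling_pairs` lists the tensor-wise pair first, the loop in `tuned_scaled_mm` below resolves that degenerate case to tensor-wise scaling.
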
```diff
@@ -1404,8 +1439,29 @@ def tuned_scaled_mm(
     # TODO (paulzhan): There is no template that exists for bias and TMA
     # Don't run tma template currently if bias exist
     if use_triton_tma_template(mat_a, mat_b, output_layout=layout) and not bias:
-        templates_to_use.append(scaled_mm_device_tma_template)
-        kwarg_overrides[scaled_mm_device_tma_template.uid] = overriders
+        scale_a_size, scale_b_size = scale_a_real.shape, scale_b_real.shape
+
+        for scale_option_a, scale_option_b in scaling_pairs:
+            if is_desired_scaling(
+                mat_a, scale_a_size, scale_option_a
+            ) and is_desired_scaling(
+                mat_b, scale_b_size, scale_option_b, transpose=True
+            ):
+                overriders["SCALE_RECIPE_A"] = scale_option_a.value
+                overriders["SCALE_RECIPE_B"] = scale_option_b.value
+                break
+
+        if (
+            "SCALE_RECIPE_A" not in overriders
+        ):  # verify that shapes are supported by at least one existing pairing
+            raise AssertionError(
+                f"Inductor Triton does not support scale_a.shape = {scale_a_size}, scale_b.shape = {scale_b_size}"
+            )
+
+        templates_to_use.append(scaled_mm_device_tma_epilogue_scaling_template)
+        kwarg_overrides[scaled_mm_device_tma_epilogue_scaling_template.uid] = (
+            overriders
+        )

     if (
         use_triton_blackwell_tma_template(mat_a, mat_b, output_layout=layout)
```
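
A compiled counterpart of the eager example near the top, which is the path the updated tests exercise (a sketch; assumes the max-autotune and persistent-TMA settings from the test plan):

```python
import torch

@torch.compile(mode="max-autotune")
def compiled_scaled_mm(a, b, scale_a, scale_b):
    return torch._scaled_mm(
        a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16
    )

# With (M, 1) / (1, N) scales, tuned_scaled_mm infers (RowWise, RowWise) and the
# generated kernel is specialized with SCALE_RECIPE_A : tl.constexpr = 1 and
# SCALE_RECIPE_B : tl.constexpr = 1, which the tests FileCheck for.
```
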


```diff
@@ -21,7 +21,7 @@ from ..kernel.mm import (
     blackwell_ws_persistent_device_tma_mm_template,
     mm_template,
     persistent_tma_mm_template,
-    scaled_mm_device_tma_template,
+    scaled_mm_device_tma_epilogue_scaling_template,
 )
 from ..kernel.mm_plus_mm import mm_plus_mm_template
 from ..kernel_inputs import KernelInputs, MMKernelInputs
@@ -1847,7 +1847,7 @@ class BaseScaledMMConfigMixin(MMTemplateConfigMixin):
     ) -> Generator[dict[str, Any], None, None]:
         """
         Generate scaled MM template configs with scaled MM-specific options.
-        Handles the remaining logic from mm_common including assertions and SCALING_ROWWISE.
+        Handles the remaining logic from mm_common, including assertions.
         """
         kernel_inputs = self.adjust_kernel_inputs(kernel_inputs, op_name)
         input_nodes = kernel_inputs.nodes()
@@ -1897,9 +1897,6 @@ class BaseScaledMMConfigMixin(MMTemplateConfigMixin):
         # Add scaled MM-specific options (moved from mm_common.scaled_mm_options)
         # Override accumulator type for scaled MM
         template_kwargs["ACC_TYPE"] = "tl.float32"

-        # Add SCALING_ROWWISE attribute based on scale tensor shapes
-        both_scalar_like = is_scalar_like(size_a) and is_scalar_like(size_b)
-        template_kwargs["SCALING_ROWWISE"] = not both_scalar_like

         yield template_kwargs
@@ -2127,13 +2124,15 @@ class CUDAScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, CUDAConfigHeuristic):


 @register_template_heuristic(
-    scaled_mm_device_tma_template.uid,
+    scaled_mm_device_tma_epilogue_scaling_template.uid,
     "cuda",
     register=torch.version.hip is None,
     op_name="scaled_mm",
 )
-class CUDAScaledTMATemplateConfigHeuristic(ScaledTMAConfigMixin, CUDAConfigHeuristic):
-    """Scaled TMA template heuristic for CUDA"""
+class CUDAScaledTMAEpilogueScalingTemplateConfigHeuristic(
+    ScaledTMAConfigMixin, CUDAConfigHeuristic
+):
+    """Scaled TMA template heuristic for CUDA: epilogue scaling variants (TensorWise, RowWise)"""

     def __init__(self) -> None:
         super().__init__()
```