Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[Inductor][Triton][FP8] Refactor scaled_mm template to accept scaling mode (#164318)
Summary: Refactor the `scaled_mm` Inductor template to support template choice based on scaling mode. This sets up the infrastructure for adding templates for new scaling modes, such as deepseek-style scaling (a follow-up diff): the newer scaling modes (deepseek, block, group) apply scales before accumulation, whereas per-tensor and per-row scaling apply them after accumulation. This change also enables Inductor to infer the scaling type from the shapes of the scale tensors, which makes the existing infrastructure more extensible to new scaling modes.

Test Plan:
```
TORCHINDUCTOR_CACHE_DIR=~/personal/cache_dir_inductor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 ENABLE_PERSISTENT_TMA_MATMUL=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 buck2 run mode/{opt,inplace} pytorch/tritonbench:run -- --op fp8_gemm --only torch_fp8_gemm,pt2_fp8_gemm --metrics tflops,accuracy --m 256 --n 768 --k 512 --output="/home/jananisriram/personal/random_bench.csv" --scaling_rowwise --atol=20 --rtol=2 2>&1 | tee ~/personal/random.log
```

Differential Revision: D83591083

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164318
Approved by: https://github.com/drisspg, https://github.com/slayton58
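The shape-based inference works roughly as follows. This is a minimal sketch, not the Inductor implementation: it uses concrete Python ints where the real code queries `V.graph.sizevars.statically_known_equals` on symbolic sizes, and the helper names mirror the `_is_tensorwise_scaling` / `_is_rowwise_scaling` helpers added in this diff.

```
def is_tensorwise_scaling(scale_shape: tuple[int, ...]) -> bool:
    # A tensor-wise scale is scalar-like: 0-d, or every dim statically 1.
    return len(scale_shape) == 0 or all(d == 1 for d in scale_shape)

def is_rowwise_scaling(scale_shape: tuple[int, ...], transpose: bool) -> bool:
    # Row-wise scales are (M, 1) for A and (1, N) for B (checked transposed).
    idx = 0 if transpose else -1
    return scale_shape[idx] == 1

# e.g. for a (256, 512) @ (512, 768) scaled mm:
assert is_tensorwise_scaling(())                      # per-tensor scalar scale
assert is_rowwise_scaling((256, 1), transpose=False)  # per-row scale for A
assert is_rowwise_scaling((1, 768), transpose=True)   # per-column scale for B
```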
committed by PyTorch MergeBot
parent aba8c43594
commit 9bf5b38c14
@@ -623,7 +623,8 @@ class TestFP8Lowering(TestCase):
             bias,
         )

-        FileCheck().check("SCALING_ROWWISE : tl.constexpr = False").run(code[0])
+        FileCheck().check("SCALE_RECIPE_A : tl.constexpr = 0").run(code[0])
+        FileCheck().check("SCALE_RECIPE_B : tl.constexpr = 0").run(code[0])
         self.assertEqual(y_eager.dtype, dtype)
         self.assertEqual(y_compiled.dtype, dtype)
         # depending on the kernel config (BLOCK_M size, etc) selected during Inductor
@@ -768,7 +769,8 @@ class TestFP8Lowering(TestCase):
             bias,
         )

-        FileCheck().check("SCALING_ROWWISE : tl.constexpr = True").run(code[0])
+        FileCheck().check("SCALE_RECIPE_A : tl.constexpr = 1").run(code[0])
+        FileCheck().check("SCALE_RECIPE_B : tl.constexpr = 1").run(code[0])
         self.assertEqual(y_eager.dtype, dtype)
         self.assertEqual(y_compiled.dtype, dtype)
         torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)
@@ -16,6 +16,7 @@ from torch._inductor.codegen.cpp_gemm_template import CppGemmTemplate
 from torch._inductor.remote_gemm_autotune_cache import gen_best_config
 from torch._inductor.virtualized import ops, V
 from torch.fx.experimental.proxy_tensor import make_fx
+from torch.nn.functional import ScalingType  # type: ignore[attr-defined]
 from torch.torch_version import TorchVersion

 from .. import config as inductor_config
@@ -372,15 +373,11 @@ persistent_tma_mm_template = TritonTemplate(

 load_scales = r"""
 @triton.jit
-def load_scales(a_scale_ptr, b_scale_ptr, SCALING_ROWWISE: tl.constexpr):
-    if SCALING_ROWWISE:
-        # For row-wise scaling, we'll return the pointers
-        return a_scale_ptr, b_scale_ptr
+def load_scales(scale_ptr, SCALE_RECIPE: tl.constexpr):
+    if SCALE_RECIPE == 0:
+        return tl.load(scale_ptr)  # For tensor-wise scaling, we'll load the scalar values
     else:
-        # For per-tensor scaling, we'll load the scalar values
-        a_scale = tl.load(a_scale_ptr)
-        b_scale = tl.load(b_scale_ptr)
-        return a_scale, b_scale
+        return scale_ptr  # For all other scaling recipes, we'll return the pointers
 """


@@ -390,7 +387,8 @@ def apply_scaling(
     accumulator,
     a_scale,
     b_scale,
-    SCALING_ROWWISE: tl.constexpr,
+    SCALE_RECIPE_A: tl.constexpr,
+    SCALE_RECIPE_B: tl.constexpr,
     offs_cm,
     offs_cn,
     M,
@@ -398,7 +396,7 @@ def apply_scaling(
     stride_a_scale_m,
     stride_b_scale_n,
 ):
-    if SCALING_ROWWISE:
+    if SCALE_RECIPE_A == 1 and SCALE_RECIPE_B == 1:  # (ScalingType.RowWise, ScalingType.RowWise)
         # For row-wise scaling, we need to load the scales for each row/column
         a_scales = tl.load(
             a_scale + (offs_cm * stride_a_scale_m),
@@ -411,7 +409,7 @@ def apply_scaling(
             other=0.0,
         )
         acc_scale = a_scales[:, None] * b_scales[None, :]
-    else:
+    else:  # (ScalingType.TensorWise, ScalingType.TensorWise)
         # For per-tensor scaling, we can directly use the loaded scalar values
         acc_scale = a_scale * b_scale

@@ -419,7 +417,7 @@ def apply_scaling(
 """


-device_tma = r"""
+scaled_mm_device_tma_epilogue_scaling = r"""
     {{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}}
     M = {{size("A", 0)}}
     N = {{size("B", 1)}}
@@ -433,11 +431,14 @@ device_tma = r"""
     stride_bk = {{stride("B", 0)}}
     stride_bn = {{stride("B", 1)}}

-    if SCALING_ROWWISE:
+    if SCALE_RECIPE_A == 1:  # ScalingType.RowWise
         stride_a_scale_m = 1
-        stride_b_scale_n = 1
     else:
         stride_a_scale_m = 0
+
+    if SCALE_RECIPE_B == 1:  # ScalingType.RowWise
+        stride_b_scale_n = 1
+    else:
         stride_b_scale_n = 0

     start_pid = tl.program_id(axis=0).to(INDEX_DTYPE)
@@ -500,7 +501,8 @@ device_tma = r"""

     num_pid_in_group = GROUP_M * num_pid_n
     accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
-    a_scale, b_scale = load_scales(A_inverse_scale, B_inverse_scale, SCALING_ROWWISE)
+    a_scale = load_scales(A_inverse_scale, SCALE_RECIPE_A)
+    b_scale = load_scales(B_inverse_scale, SCALE_RECIPE_B)

     for _ in range(0, k_tiles * tiles_per_SM):
         ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
@@ -542,7 +544,8 @@ device_tma = r"""
                 accumulator,
                 a_scale,
                 b_scale,
-                SCALING_ROWWISE,
+                SCALE_RECIPE_A,
+                SCALE_RECIPE_B,
                 offs_cm,
                 offs_cn,
                 M,
@@ -570,10 +573,10 @@ device_tma = r"""
 """


-scaled_mm_device_tma_template = TritonTemplate(
-    name="scaled_mm_device_tma",
+scaled_mm_device_tma_epilogue_scaling_template = TritonTemplate(
+    name="scaled_mm_device_tma_epilogue_scaling",
     grid=persistent_mm_grid,
-    source=device_tma + load_scales + apply_scaling,
+    source=scaled_mm_device_tma_epilogue_scaling + load_scales + apply_scaling,
 )

 _compute_blackwell_pid = r"""
@@ -1319,6 +1322,38 @@ def tuned_sparse_semi_structured_mm(
     )


+scaling_pairs = [
+    (ScalingType.TensorWise, ScalingType.TensorWise),
+    (ScalingType.RowWise, ScalingType.RowWise),
+]
+
+
+def _is_tensorwise_scaling(sz: Any) -> bool:
+    return (len(sz) == 0) or all(
+        V.graph.sizevars.statically_known_equals(d, 1) for d in sz
+    )
+
+
+def _is_rowwise_scaling(sz: Any, transpose: bool) -> bool:
+    idx = 0 if transpose else -1
+    return V.graph.sizevars.statically_known_equals(sz[idx], 1)
+
+
+def is_desired_scaling(
+    t: torch.Tensor,
+    scale_size: torch.Tensor,
+    scaling_type: ScalingType,
+    transpose: bool = False,
+) -> bool:
+    match scaling_type:
+        case ScalingType.TensorWise:
+            return _is_tensorwise_scaling(scale_size)
+        case ScalingType.RowWise:
+            return _is_rowwise_scaling(scale_size, transpose)
+        case _:
+            raise AssertionError(f"Unsupported scaling type {scaling_type}")
+
+
 @register_lowering(aten._scaled_mm.default, type_promotion_kind=None)  # type: ignore[misc]
 def tuned_scaled_mm(
     mat_a,
@@ -1404,8 +1439,29 @@ def tuned_scaled_mm(
     # TODO (paulzhan): There is no template that exists for bias and TMA
     # Don't run tma template currently if bias exist
     if use_triton_tma_template(mat_a, mat_b, output_layout=layout) and not bias:
-        templates_to_use.append(scaled_mm_device_tma_template)
-        kwarg_overrides[scaled_mm_device_tma_template.uid] = overriders
+        scale_a_size, scale_b_size = scale_a_real.shape, scale_b_real.shape
+
+        for scale_option_a, scale_option_b in scaling_pairs:
+            if is_desired_scaling(
+                mat_a, scale_a_size, scale_option_a
+            ) and is_desired_scaling(
+                mat_b, scale_b_size, scale_option_b, transpose=True
+            ):
+                overriders["SCALE_RECIPE_A"] = scale_option_a.value
+                overriders["SCALE_RECIPE_B"] = scale_option_b.value
+                break
+
+        if (
+            "SCALE_RECIPE_A" not in overriders
+        ):  # verify that shapes are supported by at least one existing pairing
+            raise AssertionError(
+                f"Inductor Triton does not support scale_a.shape = {scale_a_size}, scale_b.shape = {scale_b_size}"
+            )
+
+        templates_to_use.append(scaled_mm_device_tma_epilogue_scaling_template)
+        kwarg_overrides[scaled_mm_device_tma_epilogue_scaling_template.uid] = (
+            overriders
+        )

     if (
         use_triton_blackwell_tma_template(mat_a, mat_b, output_layout=layout)
@@ -21,7 +21,7 @@ from ..kernel.mm import (
     blackwell_ws_persistent_device_tma_mm_template,
     mm_template,
     persistent_tma_mm_template,
-    scaled_mm_device_tma_template,
+    scaled_mm_device_tma_epilogue_scaling_template,
 )
 from ..kernel.mm_plus_mm import mm_plus_mm_template
 from ..kernel_inputs import KernelInputs, MMKernelInputs
@@ -1847,7 +1847,7 @@ class BaseScaledMMConfigMixin(MMTemplateConfigMixin):
     ) -> Generator[dict[str, Any], None, None]:
         """
         Generate scaled MM template configs with scaled MM-specific options.
-        Handles the remaining logic from mm_common including assertions and SCALING_ROWWISE.
+        Handles the remaining logic from mm_common, including assertions.
         """
         kernel_inputs = self.adjust_kernel_inputs(kernel_inputs, op_name)
         input_nodes = kernel_inputs.nodes()
@@ -1897,9 +1897,6 @@ class BaseScaledMMConfigMixin(MMTemplateConfigMixin):
         # Add scaled MM-specific options (moved from mm_common.scaled_mm_options)
         # Override accumulator type for scaled MM
         template_kwargs["ACC_TYPE"] = "tl.float32"
-        # Add SCALING_ROWWISE attribute based on scale tensor shapes
-        both_scalar_like = is_scalar_like(size_a) and is_scalar_like(size_b)
-        template_kwargs["SCALING_ROWWISE"] = not both_scalar_like

         yield template_kwargs

@@ -2127,13 +2124,15 @@ class CUDAScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, CUDAConfigHeuristic):


 @register_template_heuristic(
-    scaled_mm_device_tma_template.uid,
+    scaled_mm_device_tma_epilogue_scaling_template.uid,
     "cuda",
     register=torch.version.hip is None,
     op_name="scaled_mm",
 )
-class CUDAScaledTMATemplateConfigHeuristic(ScaledTMAConfigMixin, CUDAConfigHeuristic):
-    """Scaled TMA template heuristic for CUDA"""
+class CUDAScaledTMAEpilogueScalingTemplateConfigHeuristic(
+    ScaledTMAConfigMixin, CUDAConfigHeuristic
+):
+    """Scaled TMA template heuristic for CUDA: epilogue scaling variants (TensorWise, RowWise)"""

     def __init__(self) -> None:
         super().__init__()