Compare commits

...

2 Commits

Author SHA1 Message Date
4a4c81d42e [Inductor][Triton][FP8] Support tile-wise (1x128) scaling in Inductor (#165132)
Summary:

Support tile-wise `1x128` scaling in Inductor Triton for FP8 GEMMs, i.e. each scaling value for tensors `a` and `b` corresponds to a `1x128` slice of the input.

NOTE: Block-wise `128x128` and `1x128` scaling is only supported in CUDA 12.9+; therefore, tile-wise scaling is currently unsupported in `fbcode` (CUDA 12.4). Use OSS PyTorch to run tile-wise scaling (as with deepseek-style scaling).
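
For illustration, here is a minimal sketch of the scale layout this pairing expects, modeled on the unit test added in this diff. It assumes an FP8-capable GPU with CUDA 12.9+ (per the note above); `_quantize_blockwise` is the test helper added to `torch.testing._internal.inductor_utils` in this stack, and `fp8_gemm` is a hypothetical wrapper used only for this example:
```
import torch
from torch._inductor import config as inductor_config
from torch.testing._internal.inductor_utils import _quantize_blockwise

# Sizes match the TritonBench command in the test plan below (--m 256 --n 768 --k 512).
M, K, N = 256, 512, 768
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")

# 1x128 tiles along K for both operands: one (inverse) scale per 128-element slice.
x_fp8, x_scale = _quantize_blockwise(x, torch.float8_e4m3fn, block_outer=1, block_inner=128)
w_fp8, w_scale = _quantize_blockwise(w, torch.float8_e4m3fn, block_outer=1, block_inner=128)

def fp8_gemm(a, b_t, scale_a, scale_b_t):
    # scale_a: (M, K // 128); scale_b_t: (K // 128, N)
    return torch._scaled_mm(a, b_t, scale_a, scale_b_t, None, out_dtype=torch.bfloat16)

# As in the new unit test: enable the persistent TMA matmul template and max-autotune
# so Inductor can select the main-loop-scaling Triton template.
with inductor_config.patch({"triton.enable_persistent_tma_matmul": True, "max_autotune": True}):
    y = torch.compile(fp8_gemm)(x_fp8, w_fp8.t(), x_scale, w_scale.t())
```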

Test Plan:
Works out-of-the-box with TritonBench:
```
TORCHINDUCTOR_CACHE_DIR=~/personal/cache_dir_inductor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 ENABLE_PERSISTENT_TMA_MATMUL=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 buck2 run mode/{opt,inplace} pytorch/tritonbench:run -- --op fp8_gemm --only torch_fp8_gemm,pt2_fp8_gemm --metrics tflops,accuracy --m 256 --n 768 --k 512 --output="/home/jananisriram/personal/random_bench.csv" --scaling-pair=BlockWise1x128,BlockWise1x128 --atol=1e-2 --rtol=0.5
```

Reviewed By: njriasan

Differential Revision: D84025878
2025-10-20 08:32:44 -07:00
1e5a66eed1 [Inductor][Triton][FP8] Support deepseek-style scaling in Inductor (#164404)
Summary:

Support DeepSeek-style scaling in Inductor Triton for FP8 GEMMs. DeepSeek-style scaling is a colloquial term for the fine-grained mixed-precision FP8 framework used to train [DeepSeek-V3](https://arxiv.org/pdf/2412.19437), DeepSeek AI's recent MoE (Mixture of Experts) model. DeepSeek-style scaling effectively extends the dynamic range of FP8 by mitigating dequantization overhead during increased-precision accumulation, which is key to achieving more accurate FP8 GEMM results.

DeepSeek-style scaling on a matmul `A @ B` uses two different scaling strategies to balance numerical stability and training efficiency:
- Activations (input tensor `A`): tile-wise (1x128 across shape `(M, K)`)
- Weights (input tensor `B`): block-wise (128x128 across shape `(N, K)`)

This diff enables Inductor users to replicate past successes with DeepSeek-style scaling, achieving higher numerical stability while increasing training efficiency.
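
As a concrete illustration, here is a minimal sketch of the scale shapes this pairing expects, following the new `meta_scaled_mm` shape checks in this diff. It is a sketch only, assuming an FP8-capable GPU with CUDA 12.9+; `_quantize_blockwise` is the test helper added to `torch.testing._internal.inductor_utils` in this diff, and the sizes are taken from the test plan below:
```
import torch
from torch.testing._internal.inductor_utils import _quantize_blockwise

# Sizes from the test plan below (--m 4096 --n 768 --k 512); K and N are multiples of 128.
M, K, N = 4096, 512, 768
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")  # activations A
w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")  # weights B

# Activations: tile-wise 1x128 along K -> inverse scale of shape (M, K // 128).
x_fp8, x_scale = _quantize_blockwise(x, torch.float8_e4m3fn, block_outer=1, block_inner=128)
# Weights: block-wise 128x128 -> inverse scale of shape (N // 128, K // 128).
w_fp8, w_scale = _quantize_blockwise(w, torch.float8_e4m3fn, block_outer=128, block_inner=128)

# scale_a is (M, K // 128); scale_b is passed transposed as (K // 128, N // 128).
y = torch._scaled_mm(
    x_fp8, w_fp8.t(),
    x_scale, w_scale.t(),
    None,                     # no bias
    out_dtype=torch.bfloat16,
)
```
Compiling the same call with `max_autotune` and `triton.enable_persistent_tma_matmul` enabled (as in the new unit test) routes it through the new main-loop-scaling Triton template.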

NOTE: Block-wise 128x128 scaling is only supported in CUDA 12.9+; therefore, DeepSeek-style scaling is currently unsupported in `fbcode` (CUDA 12.4). Use OSS PyTorch to run DeepSeek-style scaling.

NOTE: FP8 accuracy is inherently unstable, so even with high tolerances, TritonBench accuracy checks against the reference `torch` implementation may not pass.

Test Plan:
Note that this command only works together with the benchmark changes in a follow-up PR. In OSS PyTorch, run:
```
TORCHINDUCTOR_CACHE_DIR=~/personal/cache_dir_inductor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 ENABLE_PERSISTENT_TMA_MATMUL=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 python run.py --op fp8_gemm --only torch_fp8_gemm,pt2_fp8_gemm --metrics tflops,accuracy --m 4096 --n 768 --k 512 --output="{output_dir}/deepseek_bench.csv" --scaling-pair=BlockWise1x128,BlockWise128x128 --atol=1e-2 --rtol=0.5 2>&1 | tee ~/personal/deepseek_style/deepseek_bench.log
```

Reviewed By: slayton58

Differential Revision: D83609850
2025-10-20 08:32:44 -07:00
5 changed files with 488 additions and 27 deletions

View File

@@ -20,6 +20,7 @@ from torch.testing._internal.common_utils import (
parametrize,
)
from torch.testing._internal.inductor_utils import (
_quantize_blockwise,
_quantize_rowwise,
_quantize_tensorwise,
_to_fp8_saturated,
@@ -775,6 +776,91 @@ class TestFP8Lowering(TestCase):
self.assertEqual(y_compiled.dtype, dtype)
torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@unittest.skipIf(
not has_triton_tma_device(), "Need device-side TMA support in Triton"
)
@parametrize("shape", ((16,32,32), (1024,1024,512)))
@parametrize("use_fast_accum", (False, True))
@parametrize(
"scaling_block_sizes", ((1,128,128,128), (1,128,1,128))
) # (BlockWise1x128, BlockWise128x128), (BlockWise1x128, BlockWise1x128)
def test_main_loop_scaling(
self,
shape: tuple[int, int, int],
use_fast_accum: bool,
scaling_block_sizes: tuple[int, int, int, int],
):
# Only bf16 output type is supported for non-tensorwise scaling, not fp32
dtype: torch.dtype = torch.bfloat16
device = "cuda"
dtype_float8 = torch.float8_e4m3fn
dtype_float8 = _fix_fp8_dtype_for_rocm(dtype_float8, device)
M, K, N = shape # Matmul Y = X [M, K] x W [N, K]
x = torch.randn(M, K, dtype=dtype, device=device)
w = torch.randn(N, K, dtype=dtype, device=device)
bias = None
am, ak, bn, bk = scaling_block_sizes
# quantize weight (prior to inference)
w_fp8, w_inverse_scale = _quantize_blockwise(
w, dtype_float8, block_outer=bn, block_inner=bk
)
w_t_fp8 = w_fp8.t()
w_inverse_scale = w_inverse_scale.t() # scale_b should be (1, N)
# quantize input x
x_fp8, x_inverse_scale = _quantize_blockwise(
x, dtype_float8, block_outer=am, block_inner=ak
)
def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
y = torch._scaled_mm(
x_fp8,
w_t_fp8,
x_inverse_scale,
w_inverse_scale,
bias,
out_dtype=dtype,
use_fast_accum=use_fast_accum,
)
return y
y_eager = linear(
x_fp8,
x_inverse_scale,
w_t_fp8,
w_inverse_scale,
bias,
)
with config.patch(
{
"triton.enable_persistent_tma_matmul": True,
"test_configs.autotune_choice_name_regex": "triton_scaled_mm_device_tma",
"max_autotune_gemm_backends": "TRITON",
"max_autotune": True,
}
):
linear_compiled = torch.compile(
linear, backend="inductor", mode="max-autotune"
)
y_compiled, code = run_and_get_code(
linear_compiled,
x_fp8,
x_inverse_scale,
w_t_fp8,
w_inverse_scale,
bias,
)
FileCheck().check("SCALE_RECIPE_A : tl.constexpr = 2").run(code[0])
FileCheck().check("SCALE_RECIPE_B : tl.constexpr = 3").run(code[0])
self.assertEqual(y_eager.dtype, dtype)
self.assertEqual(y_compiled.dtype, dtype)
torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("M", (1, 3, 33, 257, 1024))
@parametrize("K", (16, 32, 1024))

View File

@@ -42,6 +42,7 @@ from ..select_algorithm import (
)
from ..utils import (
_use_cutlass_for_op,
ceildiv,
use_aten_gemm_kernels,
use_ck_gemm_template,
use_ck_tile_gemm_template,
@@ -579,6 +580,224 @@ scaled_mm_device_tma_epilogue_scaling_template = TritonTemplate(
source=scaled_mm_device_tma_epilogue_scaling + load_scales + apply_scaling,
)
blockwise1xTILESIZE_scaling = r"""
@triton.jit
def blockwise1xTILESIZE_scaling(
pid,
scale,
ki,
lhs_size,
lhs_blocks,
k_blocks,
BLOCK_lhs: tl.constexpr,
BLOCK_K: tl.constexpr,
MIN_BLOCK_TILE_K: tl.constexpr,
TILE_SIZE: tl.constexpr,
):
row_offs_scale = pid * BLOCK_lhs + tl.arange(0, BLOCK_lhs)
col_offs_scale = ki * tl.cdiv(BLOCK_K, TILE_SIZE) + tl.arange(0, (BLOCK_K + TILE_SIZE - 1) // TILE_SIZE)
ptrs = scale + row_offs_scale[:, None] * k_blocks + col_offs_scale[None, :]
mask = (row_offs_scale[:, None] < lhs_size) & (col_offs_scale[None, :] < k_blocks)
scale_block = tl.load(ptrs, mask=mask, other=1.0)
scale_expanded = scale_block[:, :, None]
scale_expanded = tl.broadcast_to(
scale_expanded,
(BLOCK_lhs, (BLOCK_K + TILE_SIZE - 1) // TILE_SIZE, MIN_BLOCK_TILE_K)
)
scale_expanded = scale_expanded.reshape(
BLOCK_lhs,
((BLOCK_K + TILE_SIZE - 1) // TILE_SIZE) * MIN_BLOCK_TILE_K
)
return scale_expanded
"""
blockwise128x128_scaling = r"""
@triton.jit
def blockwise128x128_scaling(
pid,
scale,
ki,
lhs_blocks,
k_blocks,
BLOCK_lhs: tl.constexpr,
BLOCK_K: tl.constexpr,
MIN_BLOCK_TILE_lhs: tl.constexpr,
MIN_BLOCK_TILE_K: tl.constexpr,
):
row_offs_scale = pid * tl.cdiv(BLOCK_lhs, 128) + tl.arange(0, (BLOCK_lhs + 128 - 1) // 128)
col_offs_scale = ki * tl.cdiv(BLOCK_K, 128) + tl.arange(0, (BLOCK_K + 128 - 1) // 128)
ptrs = scale + row_offs_scale[:, None] * k_blocks + col_offs_scale[None, :]
mask = (row_offs_scale[:, None] < lhs_blocks) & (col_offs_scale[None, :] < k_blocks)
scale_block = tl.load(ptrs, mask=mask, other=1.0)
scale_expanded = scale_block[:, :, None, None]
scale_expanded = tl.broadcast_to(
scale_expanded,
((BLOCK_lhs + 128 - 1) // 128, (BLOCK_K + 128 - 1) // 128, MIN_BLOCK_TILE_lhs, MIN_BLOCK_TILE_K)
)
scale_expanded = scale_expanded.reshape(
((BLOCK_lhs + 128 - 1) // 128) * MIN_BLOCK_TILE_lhs,
((BLOCK_K + 128 - 1) // 128) * MIN_BLOCK_TILE_K
)
return scale_expanded
"""
scaled_mm_device_tma_main_loop_scaling = r"""
{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}}
M = {{size("A", 0)}}
N = {{size("B", 1)}}
K = {{size("A", 1)}}
if M * N == 0:
# early exit due to zero-size input(s)
return
stride_am = {{stride("A", 0)}}
stride_bn = {{stride("B", 1)}}
start_pid = tl.program_id(axis=0).to(INDEX_DTYPE)
num_pid_m = tl.cdiv(M, BLOCK_M)
num_pid_n = tl.cdiv(N, BLOCK_N)
k_tiles = tl.cdiv(K, BLOCK_K)
num_tiles = num_pid_m * num_pid_n
a_desc = triton.language.make_tensor_descriptor(
base=A,
shape=[M, K],
strides=[stride_am, 1],
block_shape=[BLOCK_M, BLOCK_K],
)
b_desc = triton.language.make_tensor_descriptor(
base=B,
shape=[N, K],
strides=[stride_bn, 1],
block_shape=[BLOCK_N, BLOCK_K],
)
tiles_per_SM = num_tiles // NUM_SMS
if start_pid < num_tiles % NUM_SMS:
tiles_per_SM += 1
tile_id = start_pid - NUM_SMS
ki = -1
pid_m = 0
pid_n = 0
offs_am = 0
offs_bn = 0
num_pid_in_group = GROUP_M * num_pid_n
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
a_scale = load_scales(A_inverse_scale, SCALE_RECIPE_A)
b_scale = load_scales(B_inverse_scale, SCALE_RECIPE_B)
for _ in range(0, k_tiles * tiles_per_SM):
ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
if ki == 0:
tile_id += NUM_SMS
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
pid_m = first_pid_m + (tile_id % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_M
offs_bn = pid_n * BLOCK_N
offs_k = ki * BLOCK_K
a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k])
b = tl.load_tensor_descriptor(b_desc, [offs_bn, offs_k])
am_blocks = tl.cdiv(M, TILE_SIZE_A)
ak_blocks = tl.cdiv(K, TILE_SIZE_A)
bn_blocks = tl.cdiv(N, TILE_SIZE_B)
bk_blocks = tl.cdiv(K, TILE_SIZE_B)
{%- if SCALE_RECIPE_A == 5 %} # ScalingType.Blockwise128x128
scale_a_block = blockwise128x128_scaling(
pid_m,
a_scale,
ki,
am_blocks,
ak_blocks,
BLOCK_M,
BLOCK_K,
MIN_BLOCK_TILE_AM,
MIN_BLOCK_TILE_AK,
)
{%- else %} # ScalingType.Blockwise1xTILESIZE
scale_a_block = blockwise1xTILESIZE_scaling(
pid_m,
a_scale,
ki,
M,
am_blocks,
ak_blocks,
BLOCK_M,
BLOCK_K,
MIN_BLOCK_TILE_AK,
TILE_SIZE_A,
)
{%- endif %}
{%- if SCALE_RECIPE_B == 5 %} # ScalingType.Blockwise128x128
scale_b_block = blockwise128x128_scaling(
pid_n,
b_scale,
ki,
bn_blocks,
bk_blocks,
BLOCK_N,
BLOCK_K,
MIN_BLOCK_TILE_BN,
MIN_BLOCK_TILE_BK,
)
{%- else %} # ScalingType.Blockwise1xTILESIZE
scale_b_block = blockwise1xTILESIZE_scaling(
pid_n,
b_scale,
ki,
N,
bn_blocks,
bk_blocks,
BLOCK_N,
BLOCK_K,
MIN_BLOCK_TILE_BK,
TILE_SIZE_B,
)
{%- endif %}
a_scaled = a * scale_a_block
b_scaled = b * scale_b_block
accumulator = tl.dot(a_scaled, b_scaled.T, accumulator)
if ki == k_tiles - 1:
offs_cm = offs_am + tl.arange(0, BLOCK_M)
offs_cn = offs_bn + tl.arange(0, BLOCK_N)
# inductor generates a suffix
{{store_output(
("offs_am", "offs_bn"),
"accumulator",
indent_width=12,
val_shape=("BLOCK_M", "BLOCK_N"),
block_indexing=True,
)}}
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
"""
scaled_mm_device_tma_main_loop_scaling_template = TritonTemplate(
name="scaled_mm_device_tma_main_loop_scaling",
grid=persistent_mm_grid,
source=scaled_mm_device_tma_main_loop_scaling
+ load_scales
+ blockwise1xTILESIZE_scaling
+ blockwise128x128_scaling,
)
_compute_blackwell_pid = r"""
@triton.jit
def _compute_pid(tile_id, num_pid_in_group, grid_m, GROUP_M: tl.constexpr, NUM_SMS: tl.constexpr):
@@ -1325,9 +1544,15 @@ def tuned_sparse_semi_structured_mm(
scaling_pairs = [
(ScalingType.TensorWise, ScalingType.TensorWise),
(ScalingType.RowWise, ScalingType.RowWise),
(ScalingType.BlockWise1x128, ScalingType.BlockWise128x128),
(ScalingType.BlockWise1x128, ScalingType.BlockWise1x128),
]
epilogue_scaling_types = [ScalingType.TensorWise, ScalingType.RowWise]
main_loop_scaling_types = [ScalingType.BlockWise1x128, ScalingType.BlockWise128x128]
def _is_tensorwise_scaling(sz: Any) -> bool:
return (len(sz) == 0) or all(
V.graph.sizevars.statically_known_equals(d, 1) for d in sz
@@ -1339,8 +1564,24 @@ def _is_rowwise_scaling(sz: Any, transpose: bool) -> bool:
return V.graph.sizevars.statically_known_equals(sz[idx], 1)
def _is_blockwise1xTILESIZE_scaling(sz: Any, tensor_sz: Any, tile_size: int, transpose: bool) -> bool:
lhs = 1 if transpose else 0
rhs = 0 if transpose else 1
return V.graph.sizevars.statically_known_equals(
sz[lhs], tensor_sz[lhs]
) and V.graph.sizevars.statically_known_equals(
sz[rhs], ceildiv(tensor_sz[rhs], tile_size)
)
def _is_blockwise128x128_scaling(sz: Any, tensor_sz: Any) -> bool:
return V.graph.sizevars.statically_known_equals(
sz[0], ceildiv(tensor_sz[0], 128)
) and V.graph.sizevars.statically_known_equals(sz[1], ceildiv(tensor_sz[1], 128))
def is_desired_scaling(
t: torch.Tensor,
t: Any,
scale_size: torch.Tensor,
scaling_type: ScalingType,
transpose: bool = False,
@@ -1350,10 +1591,43 @@ def is_desired_scaling(
return _is_tensorwise_scaling(scale_size)
case ScalingType.RowWise:
return _is_rowwise_scaling(scale_size, transpose)
case ScalingType.BlockWise1x128:
return _is_blockwise1xTILESIZE_scaling(scale_size, t.get_size(), 128, transpose)
case ScalingType.BlockWise128x128:
return _is_blockwise128x128_scaling(scale_size, t.get_size())
case _:
raise AssertionError(f"Unsupported scaling type {scaling_type}")
def get_tile_size(scale_option) -> int:
match scale_option:
case ScalingType.BlockWise128x128:
return 128
case ScalingType.BlockWise1x128:
return 128
case _:
raise AssertionError(
f"Unsupported scaling type {scale_option} in get_tile_size"
)
def get_scaling_options(
mat_a: Any,
mat_b: Any,
scale_a_size: torch.Tensor,
scale_b_size: torch.Tensor,
) -> tuple[ScalingType, ScalingType]:
for scale_option_a, scale_option_b in scaling_pairs:
if is_desired_scaling(
mat_a, scale_a_size, scale_option_a
) and is_desired_scaling(mat_b, scale_b_size, scale_option_b, transpose=True):
return scale_option_a, scale_option_b
raise AssertionError(
f"Inductor Triton does not support scale_a.shape = {scale_a_size}, scale_b.shape = {scale_b_size}"
) # verify that shapes are supported by at least one existing pairing
@register_lowering(aten._scaled_mm.default, type_promotion_kind=None) # type: ignore[misc]
def tuned_scaled_mm(
mat_a,
@@ -1441,27 +1715,36 @@ def tuned_scaled_mm(
if use_triton_tma_template(mat_a, mat_b, output_layout=layout) and not bias:
scale_a_size, scale_b_size = scale_a_real.shape, scale_b_real.shape
for scale_option_a, scale_option_b in scaling_pairs:
if is_desired_scaling(
mat_a, scale_a_size, scale_option_a
) and is_desired_scaling(
mat_b, scale_b_size, scale_option_b, transpose=True
):
overriders["SCALE_RECIPE_A"] = scale_option_a.value
overriders["SCALE_RECIPE_B"] = scale_option_b.value
break
scale_option_a, scale_option_b = get_scaling_options(
mat_a, mat_b, scale_a_size, scale_b_size
)
overriders["SCALE_RECIPE_A"] = scale_option_a.value
overriders["SCALE_RECIPE_B"] = scale_option_b.value
if (
"SCALE_RECIPE_A" not in overriders
): # verify that shapes are supported by at least one existing pairing
raise AssertionError(
f"Inductor Triton does not support scale_a.shape = {scale_a_size}, scale_b.shape = {scale_b_size}"
scale_option_a in epilogue_scaling_types
and scale_option_b in epilogue_scaling_types
):
templates_to_use.append(scaled_mm_device_tma_epilogue_scaling_template)
kwarg_overrides[scaled_mm_device_tma_epilogue_scaling_template.uid] = (
overriders
)
elif (
scale_option_a in main_loop_scaling_types
and scale_option_b in main_loop_scaling_types
):
overriders["TILE_SIZE_A"] = get_tile_size(scale_option_a)
overriders["TILE_SIZE_B"] = get_tile_size(scale_option_b)
templates_to_use.append(scaled_mm_device_tma_epilogue_scaling_template)
kwarg_overrides[scaled_mm_device_tma_epilogue_scaling_template.uid] = (
overriders
)
templates_to_use.append(scaled_mm_device_tma_main_loop_scaling_template)
kwarg_overrides[scaled_mm_device_tma_main_loop_scaling_template.uid] = (
overriders
)
else:
raise AssertionError(
"Inductor Triton does not support scaling options that are present "
+ "in both epilogue scaling and main loop scaling"
)
if (
use_triton_blackwell_tma_template(mat_a, mat_b, output_layout=layout)

View File

@@ -19,9 +19,12 @@ from .. import config, config as inductor_config
from ..kernel.bmm import bmm_template
from ..kernel.mm import (
blackwell_ws_persistent_device_tma_mm_template,
get_scaling_options,
get_tile_size,
mm_template,
persistent_tma_mm_template,
scaled_mm_device_tma_epilogue_scaling_template,
scaled_mm_device_tma_main_loop_scaling_template,
)
from ..kernel.mm_plus_mm import mm_plus_mm_template
from ..kernel_inputs import KernelInputs, MMKernelInputs
@@ -1872,11 +1875,6 @@ class BaseScaledMMConfigMixin(MMTemplateConfigMixin):
return False
def is_scalar_like(sz: Any) -> bool:
return (len(sz) == 0) or all(
V.graph.sizevars.statically_known_equals(d, 1) for d in sz
)
size_a, size_b = scale_a.get_size(), scale_b.get_size()
assert are_compatible_scales(size_a, size_b), (
"Expect scale_a and scale_b to be either both scalars (including single-element tensors) "
@@ -2140,6 +2138,72 @@ class CUDAScaledTMAEpilogueScalingTemplateConfigHeuristic(
self.mm_configs = self.scaled_persistent_mm_configs
@register_template_heuristic(
scaled_mm_device_tma_main_loop_scaling_template.uid,
"cuda",
register=torch.version.hip is None,
op_name="scaled_mm",
)
class CUDAScaledTMAMainLoopScalingTemplateConfigHeuristic(
ScaledTMAConfigMixin, CUDAConfigHeuristic
):
"""
Scaled TMA template heuristic for CUDA:
main loop scaling variants (BlockWise1x128, BlockWise1x32, BlockWise1x16, BlockWise128x128)
"""
def __init__(self) -> None:
super().__init__()
# Override mm_configs to use scaled_persistent_mm_configs for TMA
self.mm_configs = self.scaled_persistent_mm_configs
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
"""
Generate main loop scaling kernel inputs.
"""
mat_a, mat_b, scale_a, scale_b = kernel_inputs._input_nodes
scale_a_size, scale_b_size = scale_a.get_size(), scale_b.get_size()
scale_option_a, scale_option_b = get_scaling_options(
mat_a, mat_b, scale_a_size, scale_b_size
)
tile_size_a = get_tile_size(scale_option_a)
tile_size_b = get_tile_size(scale_option_b)
# Get base scaled MM template configs from superclass
for template_kwargs in super()._get_template_configs_impl(
kernel_inputs,
op_name,
):
# Add scaling-specific options for main loop scaling variants
# Inductor templates require compile-time constants passed in as tl.constexpr values.
# When the block size (BLOCK_*) is smaller than the tile size (128, 32, 16),
# scales must be broadcast to BLOCK_* (rather than to a tile_size x tile_size chunk).
template_kwargs["TILE_SIZE_A"] = tile_size_a
template_kwargs["TILE_SIZE_B"] = tile_size_b
template_kwargs["MIN_BLOCK_TILE_AM"] = min(
template_kwargs["BLOCK_M"], tile_size_a
)
template_kwargs["MIN_BLOCK_TILE_AK"] = min(
template_kwargs["BLOCK_K"], tile_size_a
)
template_kwargs["MIN_BLOCK_TILE_BK"] = min(
template_kwargs["BLOCK_K"], tile_size_b
)
template_kwargs["MIN_BLOCK_TILE_BN"] = min(
template_kwargs["BLOCK_N"], tile_size_b
)
yield template_kwargs
@register_template_heuristic(
blackwell_ws_persistent_device_tma_mm_template.uid, # regular Blackwell MM template + scaling epilogue from ScaledMMConfigMixin
"cuda",

View File

@@ -6366,7 +6366,10 @@ def meta_scaled_mm(
) or (
scale_a.dtype == torch.float8_e4m3fn
and scale_b.dtype == torch.float8_e4m3fn
)
) # note: this applies to blockwise scaling for non-FP8 types (FP8 accepts FP32 scales)
def ceil_div(a, b):
return (a + b - 1) // b
if scale_a.numel() == 1 and scale_b.numel() == 1:
# tensorwise scaling
@@ -6388,9 +6391,6 @@
block_size_mn = 128
def ceil_div(a, b):
return (a + b - 1) // b
num_k_blocks = ceil_div(_k, block_size_k)
padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4
@@ -6444,6 +6444,20 @@
scale_a.is_contiguous() and scale_b.is_contiguous(),
lambda: "Both scale_a and scale_b must be contiguous for rowwise scaling.",
)
elif (
scale_a.size(0) == m
and scale_a.size(1) == scale_b.size(0) == ceil_div(_k, 128)
and scale_b.size(1) == ceil_div(n, 128)
):
# (BlockWise1x128, BlockWise128x128)
pass # do nothing, but do not error
elif (
scale_a.size(0) == m
and scale_a.size(1) == scale_b.size(0) == ceil_div(_k, 128)
and scale_b.size(1) == n
):
# (BlockWise1x128, BlockWise1x128)
pass # do nothing, but do not error
else:
# does not match any valid scaling type
torch._check(
@@ -6452,6 +6466,10 @@
"Invalid scaling configuration. "
"For tensorwise scaling, both scales should be scalar. "
f"For rowwise scaling, scale_a should be ({m}, 1), scale_b should be (1, {n}). "
f"For (BlockWise1x128, BlockWise128x128), scale_a should be ({m}, {ceil_div(_k, 128)}), "
+ f"scale_b should be ({ceil_div(_k, 128)}, {ceil_div(n, 128)}). "
f"For (BlockWise1x128, BlockWise1x128), scale_a should be ({m}, {ceil_div(_k, 128)}), "
+ f"scale_b should be ({ceil_div(_k, 128)}, {n}). "
f"Got scale_a.size()=({scale_a.size(0)}, {scale_a.size(1)}) "
f"and scale_b.size()=({scale_b.size(0)}, {scale_b.size(1)})"
),

View File

@@ -307,6 +307,16 @@ def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype):
inverse_scale = scale.reciprocal()
return x_fp8, inverse_scale
def _quantize_blockwise(x: Tensor, float8_dtype: torch.dtype, block_outer: int, block_inner: int):
x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
scale = _amax_to_scale(amax, float8_dtype, x.dtype)
x = x.flatten(2, 3).flatten(0, 1)
scale = scale.flatten(2, 3).flatten(0, 1)
x_fp8 = _to_fp8_saturated(x * scale, float8_dtype)
inverse_scale = scale.reciprocal()
return x_fp8, inverse_scale
class MockGraphHandler(GraphLowering):
"""Minimal mock graph handler for testing virtualized context."""