Compare commits

...

2 Commits

Author SHA1 Message Date
a9226e6e41 [Inductor][Triton][FP8] Support tile-wise (1x128) scaling in Inductor (#165132)
Summary:

Support tile-wise `1x128` scaling in Inductor Triton for FP8 GEMMs, i.e. each scaling value along tensors `a` and `b` corresponds to a `1x128` slice of the input.

NOTE: Block-wise `128x128` and `1x128` scaling are only supported in CUDA 12.9+; therefore, tile-wise scaling is currently unsupported in `fbcode` (CUDA 12.4). Use OSS PyTorch to run tile-wise scaling (as with DeepSeek-style scaling).
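
As a quick illustration (not part of this PR), here is a minimal sketch of a (BlockWise1x128, BlockWise1x128) call, modeled on the unit test added in this stack; `_quantize_blockwise` is the test helper this stack adds to `torch.testing._internal.inductor_utils`, the outer-dim-major layout requirement comes from that test, and the small `fp8_gemm` wrapper exists only for illustration:
```
# Minimal sketch, assuming CUDA 12.9+ and K divisible by 128; modeled on the new
# test_main_loop_scaling test, not an authoritative API example.
import torch
from torch.testing._internal.inductor_utils import _quantize_blockwise

M, N, K = 256, 768, 512
dtype, fp8 = torch.bfloat16, torch.float8_e4m3fn
x = torch.randn(M, K, dtype=dtype, device="cuda")  # activations (M, K)
w = torch.randn(N, K, dtype=dtype, device="cuda")  # weights (N, K)

# One scale value per 1x128 slice: scales have shape (M, K // 128) and (N, K // 128).
x_fp8, x_scale = _quantize_blockwise(x, fp8, block_outer=1, block_inner=128)
w_fp8, w_scale = _quantize_blockwise(w, fp8, block_outer=1, block_inner=128)
x_scale = x_scale.t().contiguous().t()      # 1x128 scales must be outer-dim-major
w_scale = w_scale.t().contiguous().t().t()  # same, then viewed as (K // 128, N)

def fp8_gemm(a, b_t, scale_a, scale_b):
    return torch._scaled_mm(a, b_t, scale_a, scale_b, out_dtype=dtype)

y_eager = fp8_gemm(x_fp8, w_fp8.t(), x_scale, w_scale)
# The test additionally patches triton.enable_persistent_tma_matmul=True so that
# max-autotune can select the new main-loop scaling template.
y_compiled = torch.compile(fp8_gemm, mode="max-autotune")(x_fp8, w_fp8.t(), x_scale, w_scale)
```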

Test Plan:
Works out-of-the-box with TritonBench:
```
TORCHINDUCTOR_CACHE_DIR=~/personal/cache_dir_inductor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 ENABLE_PERSISTENT_TMA_MATMUL=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 buck2 run mode/{opt,inplace} pytorch/tritonbench:run -- --op fp8_gemm --only torch_fp8_gemm,pt2_fp8_gemm --metrics tflops,accuracy --m 256 --n 768 --k 512 --output="/home/jananisriram/personal/random_bench.csv" --scaling-pair=BlockWise1x128,BlockWise1x128 --atol=1e-2 --rtol=0.5
```

Reviewed By: njriasan

Differential Revision: D84025878
2025-10-27 13:51:05 -07:00
c3f066c7e4 [Inductor][Triton][FP8] Support deepseek-style scaling in Inductor (#164404)
Summary:

Support DeepSeek-style scaling in Inductor Triton for FP8 GEMMs. "DeepSeek-style scaling" is a colloquial term for the fine-grained, mixed-precision FP8 framework used to train [DeepSeek-V3](https://arxiv.org/pdf/2412.19437), DeepSeek AI's recent MoE (Mixture of Experts) model. DeepSeek-style scaling effectively extends the dynamic range of FP8 by mitigating dequantization overhead via increased-precision accumulation, which is key to achieving more accurate FP8 GEMM results.

DeepSeek-style scaling on matmul `A @ B` leverages two different types of scaling strategies to preserve a balance between numerical stability and training efficiency:
- Activations (input tensor `A`): tile-wise (1x128 across shape `(M, K)`)
- Weights (input tensor `B`): block-wise (128x128 across shape `(N, K)`)

This diff enables Inductor users to replicate past successes with DeepSeek-style scaling and achieve higher numerical stability while increasing training efficiency.
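
As a quick illustration of those scale shapes (not part of this PR), here is a minimal sketch modeled on the unit test added in this stack, using the `_quantize_blockwise` helper that this stack adds to `torch.testing._internal.inductor_utils`:
```
# Minimal sketch, assuming CUDA 12.9+ and M, N, K divisible by 128; modeled on the
# new test_main_loop_scaling test, not an authoritative API example.
import torch
from torch.testing._internal.inductor_utils import _quantize_blockwise

M, N, K = 1024, 512, 1024
dtype, fp8 = torch.bfloat16, torch.float8_e4m3fn
x = torch.randn(M, K, dtype=dtype, device="cuda")  # activations (M, K)
w = torch.randn(N, K, dtype=dtype, device="cuda")  # weights (N, K)

# Activations: tile-wise 1x128 -> scale shape (M, K // 128), kept outer-dim-major.
x_fp8, x_scale = _quantize_blockwise(x, fp8, block_outer=1, block_inner=128)
x_scale = x_scale.t().contiguous().t()

# Weights: block-wise 128x128 -> scale shape (N // 128, K // 128), passed transposed.
w_fp8, w_scale = _quantize_blockwise(w, fp8, block_outer=128, block_inner=128)

y = torch._scaled_mm(x_fp8, w_fp8.t(), x_scale, w_scale.t(), out_dtype=dtype)
```
Compiling a wrapper around this call with `mode="max-autotune"` and `triton.enable_persistent_tma_matmul=True` lets Inductor's autotuner consider the new `scaled_mm_device_tma_main_loop_scaling` template added below.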

NOTE: Block-wise 128x128 scaling is only supported in CUDA 12.9+; therefore, DeepSeek-style scaling is currently unsupported in `fbcode` (CUDA 12.4). Use OSS PyTorch to run DeepSeek-style scaling.

NOTE: FP8 accuracy is unstable even with high tolerances, which is why TritonBench accuracy comparisons against a `torch` reference implementation are unlikely to pass.

Test Plan:
Note that this command depends on the TritonBench benchmark changes in a follow-up PR. In OSS PyTorch, run:
```
TORCHINDUCTOR_CACHE_DIR=~/personal/cache_dir_inductor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 ENABLE_PERSISTENT_TMA_MATMUL=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 python run.py --op fp8_gemm --only torch_fp8_gemm,pt2_fp8_gemm --metrics tflops,accuracy --m 4096 --n 768 --k 512 --output="{output_dir}/deepseek_bench.csv" --scaling-pair=BlockWise1x128,BlockWise128x128 --atol=1e-2 --rtol=0.5 2>&1 | tee ~/personal/deepseek_style/deepseek_bench.log
```

Reviewed By: slayton58

Differential Revision: D83609850
2025-10-27 13:51:05 -07:00
5 changed files with 635 additions and 81 deletions

View File

@ -10,7 +10,9 @@ from torch._inductor import config, utils
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code
from torch.nn.functional import ScalingType # type: ignore[attr-defined]
from torch.testing._internal.common_cuda import (
_get_torch_cuda_version,
PLATFORM_SUPPORTS_FP8,
PLATFORM_SUPPORTS_MX_GEMM,
)
@ -20,6 +22,7 @@ from torch.testing._internal.common_utils import (
parametrize,
)
from torch.testing._internal.inductor_utils import (
_quantize_blockwise,
_quantize_rowwise,
_quantize_tensorwise,
_to_fp8_saturated,
@ -623,8 +626,12 @@ class TestFP8Lowering(TestCase):
bias,
)
FileCheck().check("SCALE_RECIPE_A : tl.constexpr = 0").run(code[0])
FileCheck().check("SCALE_RECIPE_B : tl.constexpr = 0").run(code[0])
FileCheck().check(
f"SCALE_RECIPE_A : tl.constexpr = {ScalingType.TensorWise.value}"
).run(code[0])
FileCheck().check(
f"SCALE_RECIPE_B : tl.constexpr = {ScalingType.TensorWise.value}"
).run(code[0])
self.assertEqual(y_eager.dtype, dtype)
self.assertEqual(y_compiled.dtype, dtype)
# depending on the kernel config (BLOCK_M size, etc) selected during Inductor
@ -769,8 +776,121 @@ class TestFP8Lowering(TestCase):
bias,
)
FileCheck().check("SCALE_RECIPE_A : tl.constexpr = 1").run(code[0])
FileCheck().check("SCALE_RECIPE_B : tl.constexpr = 1").run(code[0])
FileCheck().check(
f"SCALE_RECIPE_A : tl.constexpr = {ScalingType.RowWise.value}"
).run(code[0])
FileCheck().check(
f"SCALE_RECIPE_B : tl.constexpr = {ScalingType.RowWise.value}"
).run(code[0])
self.assertEqual(y_eager.dtype, dtype)
self.assertEqual(y_compiled.dtype, dtype)
torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@unittest.skipIf(
not has_triton_tma_device(), "Need device-side TMA support in Triton"
)
@unittest.skipIf(
_get_torch_cuda_version() < (12, 9),
"cuBLAS blockwise scaling added in CUDA 12.9",
)
@parametrize("shape", ((16, 256, 256), (1024, 512, 1024)))
@parametrize("use_fast_accum", (False, True))
@parametrize(
"scaling_block_sizes", ((1, 128, 128, 128), (1, 128, 1, 128))
) # (BlockWise1x128, BlockWise128x128), (BlockWise1x128, BlockWise1x128)
def test_main_loop_scaling(
self,
shape: tuple[int, int, int],
use_fast_accum: bool,
scaling_block_sizes: tuple[int, int, int, int],
):
# Only bf16 output type is supported for non-tensorwise scaling, not fp32
dtype: torch.dtype = torch.bfloat16
device = "cuda"
dtype_float8 = torch.float8_e4m3fn
dtype_float8 = _fix_fp8_dtype_for_rocm(dtype_float8, device)
M, N, K = shape # Matmul Y = X [M, K] x W [N, K]
x = torch.randn(M, K, dtype=dtype, device=device)
w = torch.randn(N, K, dtype=dtype, device=device)
bias = None
am, ak, bn, bk = scaling_block_sizes
# quantize weight (prior to inference)
w_fp8, w_inverse_scale = _quantize_blockwise(
w, dtype_float8, block_outer=bn, block_inner=bk
)
w_t_fp8 = w_fp8.t()
if (bn, bk) == (1, 128):
w_inverse_scale = (
w_inverse_scale.t().contiguous().t().t()
) # 1x128 blocks need scales to be outer-dim-major
else:
w_inverse_scale = w_inverse_scale.t() # scale_b should be (ceil(K / 128), ceil(N / 128)) for 128x128 blocks
# quantize input x
x_fp8, x_inverse_scale = _quantize_blockwise(
x, dtype_float8, block_outer=am, block_inner=ak
)
if (am, ak) == (1, 128):
x_inverse_scale = (
x_inverse_scale.t().contiguous().t()
) # 1x128 blocks need scales to be outer-dim-major
def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
y = torch._scaled_mm(
x_fp8,
w_t_fp8,
x_inverse_scale,
w_inverse_scale,
bias,
out_dtype=dtype,
use_fast_accum=use_fast_accum,
)
return y
y_eager = linear(
x_fp8,
x_inverse_scale,
w_t_fp8,
w_inverse_scale,
bias,
)
with config.patch(
{
"triton.enable_persistent_tma_matmul": True,
"test_configs.autotune_choice_name_regex": "triton_scaled_mm_device_tma",
"max_autotune_gemm_backends": "TRITON",
"max_autotune": True,
}
):
linear_compiled = torch.compile(
linear, backend="inductor", mode="max-autotune"
)
y_compiled, code = run_and_get_code(
linear_compiled,
x_fp8,
x_inverse_scale,
w_t_fp8,
w_inverse_scale,
bias,
)
# Verify that Inductor chooses the correct scaling recipes
FileCheck().check(
f"SCALE_RECIPE_A : tl.constexpr = {ScalingType.BlockWise1x128.value}"
).run(code[0])
if (bn, bk) == (1, 128):
check_scale_recipe_b = ScalingType.BlockWise1x128.value
else:
check_scale_recipe_b = ScalingType.BlockWise128x128.value
FileCheck().check(
f"SCALE_RECIPE_B : tl.constexpr = {check_scale_recipe_b}"
).run(code[0])
self.assertEqual(y_eager.dtype, dtype)
self.assertEqual(y_compiled.dtype, dtype)
torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)

View File

@ -42,6 +42,7 @@ from ..select_algorithm import (
)
from ..utils import (
_use_cutlass_for_op,
ceildiv,
use_aten_gemm_kernels,
use_ck_gemm_template,
use_ck_tile_gemm_template,
@ -579,6 +580,224 @@ scaled_mm_device_tma_epilogue_scaling_template = TritonTemplate(
source=scaled_mm_device_tma_epilogue_scaling + load_scales + apply_scaling,
)
blockwise1xTILESIZE_scaling = r"""
@triton.jit
def blockwise1xTILESIZE_scaling(
pid,
scale,
ki,
lhs_size,
lhs_blocks,
k_blocks,
BLOCK_lhs: tl.constexpr,
BLOCK_K: tl.constexpr,
MIN_BLOCK_TILE_K: tl.constexpr,
TILE_SIZE: tl.constexpr,
):
row_offs_scale = pid * BLOCK_lhs + tl.arange(0, BLOCK_lhs)
col_offs_scale = ki * tl.cdiv(BLOCK_K, TILE_SIZE) + tl.arange(0, (BLOCK_K + TILE_SIZE - 1) // TILE_SIZE)
ptrs = scale + row_offs_scale[:, None] * k_blocks + col_offs_scale[None, :]
mask = (row_offs_scale[:, None] < lhs_size) & (col_offs_scale[None, :] < k_blocks)
scale_block = tl.load(ptrs, mask=mask, other=1.0)
scale_expanded = scale_block[:, :, None]
scale_expanded = tl.broadcast_to(
scale_expanded,
(BLOCK_lhs, (BLOCK_K + TILE_SIZE - 1) // TILE_SIZE, MIN_BLOCK_TILE_K)
)
scale_expanded = scale_expanded.reshape(
BLOCK_lhs,
((BLOCK_K + TILE_SIZE - 1) // TILE_SIZE) * MIN_BLOCK_TILE_K
)
return scale_expanded
"""
blockwise128x128_scaling = r"""
@triton.jit
def blockwise128x128_scaling(
pid,
scale,
ki,
lhs_blocks,
k_blocks,
BLOCK_lhs: tl.constexpr,
BLOCK_K: tl.constexpr,
MIN_BLOCK_TILE_lhs: tl.constexpr,
MIN_BLOCK_TILE_K: tl.constexpr,
):
row_offs_scale = pid * tl.cdiv(BLOCK_lhs, 128) + tl.arange(0, (BLOCK_lhs + 128 - 1) // 128)
col_offs_scale = ki * tl.cdiv(BLOCK_K, 128) + tl.arange(0, (BLOCK_K + 128 - 1) // 128)
ptrs = scale + row_offs_scale[:, None] * k_blocks + col_offs_scale[None, :]
mask = (row_offs_scale[:, None] < lhs_blocks) & (col_offs_scale[None, :] < k_blocks)
scale_block = tl.load(ptrs, mask=mask, other=1.0)
scale_expanded = scale_block[:, :, None, None]
scale_expanded = tl.broadcast_to(
scale_expanded,
((BLOCK_lhs + 128 - 1) // 128, (BLOCK_K + 128 - 1) // 128, MIN_BLOCK_TILE_lhs, MIN_BLOCK_TILE_K)
)
scale_expanded = scale_expanded.reshape(
((BLOCK_lhs + 128 - 1) // 128) * MIN_BLOCK_TILE_lhs,
((BLOCK_K + 128 - 1) // 128) * MIN_BLOCK_TILE_K
)
return scale_expanded
"""
scaled_mm_device_tma_main_loop_scaling = r"""
{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}}
M = {{size("A", 0)}}
N = {{size("B", 1)}}
K = {{size("A", 1)}}
if M * N == 0:
# early exit due to zero-size input(s)
return
stride_am = {{stride("A", 0)}}
stride_bn = {{stride("B", 1)}}
start_pid = tl.program_id(axis=0).to(INDEX_DTYPE)
num_pid_m = tl.cdiv(M, BLOCK_M)
num_pid_n = tl.cdiv(N, BLOCK_N)
k_tiles = tl.cdiv(K, BLOCK_K)
num_tiles = num_pid_m * num_pid_n
a_desc = triton.language.make_tensor_descriptor(
base=A,
shape=[M, K],
strides=[stride_am, 1],
block_shape=[BLOCK_M, BLOCK_K],
)
b_desc = triton.language.make_tensor_descriptor(
base=B,
shape=[N, K],
strides=[stride_bn, 1],
block_shape=[BLOCK_N, BLOCK_K],
)
tiles_per_SM = num_tiles // NUM_SMS
if start_pid < num_tiles % NUM_SMS:
tiles_per_SM += 1
tile_id = start_pid - NUM_SMS
ki = -1
pid_m = 0
pid_n = 0
offs_am = 0
offs_bn = 0
num_pid_in_group = GROUP_M * num_pid_n
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
a_scale = load_scales(A_inverse_scale, SCALE_RECIPE_A)
b_scale = load_scales(B_inverse_scale, SCALE_RECIPE_B)
for _ in range(0, k_tiles * tiles_per_SM):
ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
if ki == 0:
tile_id += NUM_SMS
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
pid_m = first_pid_m + (tile_id % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_M
offs_bn = pid_n * BLOCK_N
offs_k = ki * BLOCK_K
a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k])
b = tl.load_tensor_descriptor(b_desc, [offs_bn, offs_k])
am_blocks = tl.cdiv(M, TILE_SIZE_A)
ak_blocks = tl.cdiv(K, TILE_SIZE_A)
bn_blocks = tl.cdiv(N, TILE_SIZE_B)
bk_blocks = tl.cdiv(K, TILE_SIZE_B)
{%- if SCALE_RECIPE_A == 5 %} # ScalingType.Blockwise128x128
scale_a_block = blockwise128x128_scaling(
pid_m,
a_scale,
ki,
am_blocks,
ak_blocks,
BLOCK_M,
BLOCK_K,
MIN_BLOCK_TILE_AM,
MIN_BLOCK_TILE_AK,
)
{%- else %} # ScalingType.Blockwise1xTILESIZE
scale_a_block = blockwise1xTILESIZE_scaling(
pid_m,
a_scale,
ki,
M,
am_blocks,
ak_blocks,
BLOCK_M,
BLOCK_K,
MIN_BLOCK_TILE_AK,
TILE_SIZE_A,
)
{%- endif %}
{%- if SCALE_RECIPE_B == 5 %} # ScalingType.Blockwise128x128
scale_b_block = blockwise128x128_scaling(
pid_n,
b_scale,
ki,
bn_blocks,
bk_blocks,
BLOCK_N,
BLOCK_K,
MIN_BLOCK_TILE_BN,
MIN_BLOCK_TILE_BK,
)
{%- else %} # ScalingType.Blockwise1xTILESIZE
scale_b_block = blockwise1xTILESIZE_scaling(
pid_n,
b_scale,
ki,
N,
bn_blocks,
bk_blocks,
BLOCK_N,
BLOCK_K,
MIN_BLOCK_TILE_BK,
TILE_SIZE_B,
)
{%- endif %}
a_scaled = a * scale_a_block
b_scaled = b * scale_b_block
accumulator = tl.dot(a_scaled, b_scaled.T, accumulator)
if ki == k_tiles - 1:
offs_cm = offs_am + tl.arange(0, BLOCK_M)
offs_cn = offs_bn + tl.arange(0, BLOCK_N)
# inductor generates a suffix
{{store_output(
("offs_am", "offs_bn"),
"accumulator",
indent_width=12,
val_shape=("BLOCK_M", "BLOCK_N"),
block_indexing=True,
)}}
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
"""
scaled_mm_device_tma_main_loop_scaling_template = TritonTemplate(
name="scaled_mm_device_tma_main_loop_scaling",
grid=persistent_mm_grid,
source=scaled_mm_device_tma_main_loop_scaling
+ load_scales
+ blockwise1xTILESIZE_scaling
+ blockwise128x128_scaling,
)
_compute_blackwell_pid = r"""
@triton.jit
def _compute_pid(tile_id, num_pid_in_group, grid_m, GROUP_M: tl.constexpr, NUM_SMS: tl.constexpr):
@ -1325,9 +1544,15 @@ def tuned_sparse_semi_structured_mm(
scaling_pairs = [
(ScalingType.TensorWise, ScalingType.TensorWise),
(ScalingType.RowWise, ScalingType.RowWise),
(ScalingType.BlockWise1x128, ScalingType.BlockWise128x128),
(ScalingType.BlockWise1x128, ScalingType.BlockWise1x128),
]
epilogue_scaling_types = [ScalingType.TensorWise, ScalingType.RowWise]
main_loop_scaling_types = [ScalingType.BlockWise1x128, ScalingType.BlockWise128x128]
def _is_tensorwise_scaling(sz: Any) -> bool:
return (len(sz) == 0) or all(
V.graph.sizevars.statically_known_equals(d, 1) for d in sz
@ -1339,8 +1564,26 @@ def _is_rowwise_scaling(sz: Any, transpose: bool) -> bool:
return V.graph.sizevars.statically_known_equals(sz[idx], 1)
def _is_blockwise1xTILESIZE_scaling(
sz: Any, tensor_sz: Any, tile_size: int, transpose: bool
) -> bool:
lhs = 1 if transpose else 0
rhs = 0 if transpose else 1
return V.graph.sizevars.statically_known_equals(
sz[lhs], tensor_sz[lhs]
) and V.graph.sizevars.statically_known_equals(
sz[rhs], ceildiv(tensor_sz[rhs], tile_size)
)
def _is_blockwise128x128_scaling(sz: Any, tensor_sz: Any) -> bool:
return V.graph.sizevars.statically_known_equals(
sz[0], ceildiv(tensor_sz[0], 128)
) and V.graph.sizevars.statically_known_equals(sz[1], ceildiv(tensor_sz[1], 128))
def is_desired_scaling(
t: torch.Tensor,
t: Any,
scale_size: torch.Tensor,
scaling_type: ScalingType,
transpose: bool = False,
@ -1350,10 +1593,45 @@ def is_desired_scaling(
return _is_tensorwise_scaling(scale_size)
case ScalingType.RowWise:
return _is_rowwise_scaling(scale_size, transpose)
case ScalingType.BlockWise1x128:
return _is_blockwise1xTILESIZE_scaling(
scale_size, t.get_size(), 128, transpose
)
case ScalingType.BlockWise128x128:
return _is_blockwise128x128_scaling(scale_size, t.get_size())
case _:
raise AssertionError(f"Unsupported scaling type {scaling_type}")
def get_tile_size(scale_option) -> int:
match scale_option:
case ScalingType.BlockWise128x128:
return 128
case ScalingType.BlockWise1x128:
return 128
case _:
raise AssertionError(
f"Unsupported scaling type {scale_option} in get_tile_size"
)
def get_scaling_options(
mat_a: Any,
mat_b: Any,
scale_a_size: torch.Tensor,
scale_b_size: torch.Tensor,
) -> tuple[ScalingType, ScalingType]:
for scale_option_a, scale_option_b in scaling_pairs:
if is_desired_scaling(
mat_a, scale_a_size, scale_option_a
) and is_desired_scaling(mat_b, scale_b_size, scale_option_b, transpose=True):
return scale_option_a, scale_option_b
raise AssertionError(
f"Inductor Triton does not support scale_a.shape = {scale_a_size}, scale_b.shape = {scale_b_size}"
) # verify that shapes are supported by at least one existing pairing
@register_lowering(aten._scaled_mm.default, type_promotion_kind=None) # type: ignore[misc]
def tuned_scaled_mm(
mat_a,
@ -1441,27 +1719,36 @@ def tuned_scaled_mm(
if use_triton_tma_template(mat_a, mat_b, output_layout=layout) and not bias:
scale_a_size, scale_b_size = scale_a_real.shape, scale_b_real.shape
for scale_option_a, scale_option_b in scaling_pairs:
if is_desired_scaling(
mat_a, scale_a_size, scale_option_a
) and is_desired_scaling(
mat_b, scale_b_size, scale_option_b, transpose=True
):
overriders["SCALE_RECIPE_A"] = scale_option_a.value
overriders["SCALE_RECIPE_B"] = scale_option_b.value
break
scale_option_a, scale_option_b = get_scaling_options(
mat_a, mat_b, scale_a_size, scale_b_size
)
overriders["SCALE_RECIPE_A"] = scale_option_a.value
overriders["SCALE_RECIPE_B"] = scale_option_b.value
if (
"SCALE_RECIPE_A" not in overriders
): # verify that shapes are supported by at least one existing pairing
raise AssertionError(
f"Inductor Triton does not support scale_a.shape = {scale_a_size}, scale_b.shape = {scale_b_size}"
scale_option_a in epilogue_scaling_types
and scale_option_b in epilogue_scaling_types
):
templates_to_use.append(scaled_mm_device_tma_epilogue_scaling_template)
kwarg_overrides[scaled_mm_device_tma_epilogue_scaling_template.uid] = (
overriders
)
elif (
scale_option_a in main_loop_scaling_types
and scale_option_b in main_loop_scaling_types
):
overriders["TILE_SIZE_A"] = get_tile_size(scale_option_a)
overriders["TILE_SIZE_B"] = get_tile_size(scale_option_b)
templates_to_use.append(scaled_mm_device_tma_epilogue_scaling_template)
kwarg_overrides[scaled_mm_device_tma_epilogue_scaling_template.uid] = (
overriders
)
templates_to_use.append(scaled_mm_device_tma_main_loop_scaling_template)
kwarg_overrides[scaled_mm_device_tma_main_loop_scaling_template.uid] = (
overriders
)
else:
raise AssertionError(
"Inductor Triton does not support scaling options that are present "
+ "in both epilogue scaling and main loop scaling"
)
if (
use_triton_blackwell_tma_template(mat_a, mat_b, output_layout=layout)

View File

@ -19,9 +19,12 @@ from .. import config, config as inductor_config
from ..kernel.bmm import bmm_template
from ..kernel.mm import (
blackwell_ws_persistent_device_tma_mm_template,
get_scaling_options,
get_tile_size,
mm_template,
persistent_tma_mm_template,
scaled_mm_device_tma_epilogue_scaling_template,
scaled_mm_device_tma_main_loop_scaling_template,
)
from ..kernel.mm_plus_mm import mm_plus_mm_template
from ..kernel_inputs import KernelInputs, MMKernelInputs
@ -1886,11 +1889,6 @@ class BaseScaledMMConfigMixin(MMTemplateConfigMixin):
return False
def is_scalar_like(sz: Any) -> bool:
return (len(sz) == 0) or all(
V.graph.sizevars.statically_known_equals(d, 1) for d in sz
)
size_a, size_b = scale_a.get_size(), scale_b.get_size()
assert are_compatible_scales(size_a, size_b), (
"Expect scale_a and scale_b to be either both scalars (including single-element tensors) "
@ -2154,6 +2152,72 @@ class CUDAScaledTMAEpilogueScalingTemplateConfigHeuristic(
self.mm_configs = self.scaled_persistent_mm_configs
@register_template_heuristic(
scaled_mm_device_tma_main_loop_scaling_template.uid,
"cuda",
register=torch.version.hip is None,
op_name="scaled_mm",
)
class CUDAScaledTMAMainLoopScalingTemplateConfigHeuristic(
ScaledTMAConfigMixin, CUDAConfigHeuristic
):
"""
Scaled TMA template heuristic for CUDA:
main loop scaling variants (BlockWise1x128, BlockWise1x32, BlockWise1x16, BlockWise128x128)
"""
def __init__(self) -> None:
super().__init__()
# Override mm_configs to use scaled_persistent_mm_configs for TMA
self.mm_configs = self.scaled_persistent_mm_configs
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
"""
Generate main loop scaling kernel inputs.
"""
mat_a, mat_b, scale_a, scale_b = kernel_inputs._input_nodes
scale_a_size, scale_b_size = scale_a.get_size(), scale_b.get_size()
scale_option_a, scale_option_b = get_scaling_options(
mat_a, mat_b, scale_a_size, scale_b_size
)
tile_size_a = get_tile_size(scale_option_a)
tile_size_b = get_tile_size(scale_option_b)
# Get base scaled MM template configs from superclass
for template_kwargs in super()._get_template_configs_impl(
kernel_inputs,
op_name,
):
# Add scaling-specific options for main loop scaling variants
# Inductor templates require compile-time constants passed in as tl.constexpr values.
# In cases in which the block size (BLOCK_*) is smaller than the tile size (128, 32, 16),
# scales must be broadcasted to BLOCK_* (rather than to a tile_sizextile_size chunk).
template_kwargs["TILE_SIZE_A"] = tile_size_a
template_kwargs["TILE_SIZE_B"] = tile_size_b
template_kwargs["MIN_BLOCK_TILE_AM"] = min(
template_kwargs["BLOCK_M"], tile_size_a
)
template_kwargs["MIN_BLOCK_TILE_AK"] = min(
template_kwargs["BLOCK_K"], tile_size_a
)
template_kwargs["MIN_BLOCK_TILE_BK"] = min(
template_kwargs["BLOCK_K"], tile_size_b
)
template_kwargs["MIN_BLOCK_TILE_BN"] = min(
template_kwargs["BLOCK_N"], tile_size_b
)
yield template_kwargs
@register_template_heuristic(
blackwell_ws_persistent_device_tma_mm_template.uid, # regular Blackwell MM template + scaling epilogue from ScaledMMConfigMixin
"cuda",

View File

@ -6365,12 +6365,18 @@ def meta_scaled_mm(
n = mat2.size(1)
is_blockwise_scaling = (
scale_a.dtype == torch.float8_e8m0fnu
and scale_b.dtype == torch.float8_e8m0fnu
) or (
scale_a.dtype == torch.float8_e4m3fn
and scale_b.dtype == torch.float8_e4m3fn
)
(
scale_a.dtype == torch.float8_e8m0fnu
and scale_b.dtype == torch.float8_e8m0fnu
)
or (
scale_a.dtype == torch.float8_e4m3fn
and scale_b.dtype == torch.float8_e4m3fn
)
) # note: this applies to blockwise scaling for non-FP8 types (FP8 accepts FP32 scales)
def ceil_div(a, b):
return (a + b - 1) // b
if scale_a.numel() == 1 and scale_b.numel() == 1:
# tensorwise scaling
@ -6392,9 +6398,6 @@ def meta_scaled_mm(
block_size_mn = 128
def ceil_div(a, b):
return (a + b - 1) // b
num_k_blocks = ceil_div(_k, block_size_k)
padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4
@ -6448,6 +6451,20 @@ def meta_scaled_mm(
scale_a.is_contiguous() and scale_b.is_contiguous(),
lambda: "Both scale_a and scale_b must be contiguous for rowwise scaling.",
)
elif (
scale_a.size(0) == m
and scale_a.size(1) == scale_b.size(0) == ceil_div(_k, 128)
and scale_b.size(1) == ceil_div(n, 128)
):
# (BlockWise1x128, BlockWise128x128)
pass # do nothing, but do not error
elif (
scale_a.size(0) == m
and scale_a.size(1) == scale_b.size(0) == ceil_div(_k, 128)
and scale_b.size(1) == n
):
# (BlockWise1x128, BlockWise1x128)
pass # do nothing, but do not error
else:
# does not match any valid scaling type
torch._check(
@ -6456,6 +6473,10 @@ def meta_scaled_mm(
"Invalid scaling configuration. "
"For tensorwise scaling, both scales should be scalar. "
f"For rowwise scaling, scale_a should be ({m}, 1), scale_b should be (1, {n}). "
f"For (BlockWise1x128, BlockWise128x128), scale_a should be ({m}, {ceil_div(_k, 128)}), "
+ f"scale_b should be ({ceil_div(_k, 128)}, {ceil_div(n, 128)}). "
f"For (BlockWise1x128, BlockWise1x128), scale_a should be ({m}, {ceil_div(_k, 128)}), "
+ f"scale_b should be ({ceil_div(_k, 128)}, {n}). "
f"Got scale_a.size()=({scale_a.size(0)}, {scale_a.size(1)}) "
f"and scale_b.size()=({scale_b.size(0)}, {scale_b.size(1)})"
),

View File

@ -1,50 +1,55 @@
# mypy: ignore-errors
import logging
import torch
import re
import unittest
import functools
import contextlib
import functools
import logging
import os
from subprocess import CalledProcessError
import re
import sys
import unittest
from subprocess import CalledProcessError
import torch
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
from torch.fx.experimental.proxy_tensor import make_fx
from torch._inductor.graph import GraphLowering
from torch._inductor.compile_fx import shape_env_from_inputs
from torch._inductor.utils import OrderedSet
from torch._inductor.codecache import CppCodeCache
from torch._inductor.custom_graph_pass import CustomGraphModulePass
from torch._inductor.codegen.common import (
get_custom_backend_config_for_device,
get_custom_backend_pass_for_device,
get_scheduling_for_device,
get_wrapper_codegen_for_device,
init_backend_registration,
register_backend_for_device
register_backend_for_device,
)
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
from torch._inductor.utils import get_gpu_shared_memory, is_big_gpu
from torch._inductor.utils import GPU_TYPES, get_gpu_type, is_gpu
from torch.utils._helion import has_helion
from torch.utils._triton import has_triton
from torch.utils._config_module import ConfigModule
from torch._inductor.compile_fx import shape_env_from_inputs
from torch._inductor.custom_graph_pass import CustomGraphModulePass
from torch._inductor.graph import GraphLowering
from torch._inductor.utils import (
get_gpu_shared_memory,
get_gpu_type,
GPU_TYPES,
is_big_gpu,
is_gpu,
OrderedSet,
)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_device_type import (
get_desired_device_type_test_bases,
)
from torch.testing._internal.common_utils import (
LazyVal,
IS_FBCODE,
)
from torch.testing._internal.common_utils import (
TestCase,
IS_CI,
IS_FBCODE,
IS_WINDOWS,
LazyVal,
TestCase,
)
from torch.utils._config_module import ConfigModule
from torch.utils._helion import has_helion
from torch.utils._triton import has_triton
log: logging.Logger = logging.getLogger(__name__)
def test_cpu():
try:
CppCodeCache.load("")
@ -57,6 +62,7 @@ def test_cpu():
):
return False
HAS_CPU = LazyVal(test_cpu)
HAS_TRITON = has_triton()
@ -65,6 +71,7 @@ HAS_HELION = has_helion()
if HAS_TRITON:
import triton
TRITON_HAS_CPU = "cpu" in triton.backends.backends
else:
TRITON_HAS_CPU = False
@ -86,16 +93,15 @@ HAS_MULTIGPU = any(
)
_desired_test_bases = get_desired_device_type_test_bases(allow_xpu=True)
RUN_GPU = (
HAS_GPU
and any(is_gpu(getattr(x, "device_type", "")) for x in _desired_test_bases)
RUN_GPU = HAS_GPU and any(
is_gpu(getattr(x, "device_type", "")) for x in _desired_test_bases
)
RUN_CPU = (
HAS_CPU
and any(getattr(x, "device_type", "") == "cpu" for x in _desired_test_bases)
RUN_CPU = HAS_CPU and any(
getattr(x, "device_type", "") == "cpu" for x in _desired_test_bases
)
def _check_has_dynamic_shape(
self: TestCase,
code,
@ -117,24 +123,31 @@ def _check_has_dynamic_shape(
def skipDeviceIf(cond, msg, *, device):
if cond:
def decorate_fn(fn):
@functools.wraps(fn)
def inner(self, *args, **kwargs):
if not hasattr(self, "device"):
warn_msg = "Expect the test class to have attribute device but not found. "
warn_msg = (
"Expect the test class to have attribute device but not found. "
)
if hasattr(self, "device_type"):
warn_msg += "Consider using the skip device decorators in common_device_type.py"
log.warning(warn_msg)
if self.device == device:
raise unittest.SkipTest(msg)
return fn(self, *args, **kwargs)
return inner
else:
def decorate_fn(fn):
return fn
return decorate_fn
def skip_windows_ci(name: str, file: str) -> None:
if IS_WINDOWS and IS_CI:
module = os.path.basename(file).strip(".py")
@ -145,36 +158,41 @@ def skip_windows_ci(name: str, file: str) -> None:
sys.exit(0)
raise unittest.SkipTest("requires sympy/functorch/filelock")
# TODO: Remove HAS_MPS condition when `HAS_GPU` includes HAS_MPS
requires_gpu = functools.partial(unittest.skipIf, not (HAS_GPU or HAS_MPS), "requires gpu")
requires_gpu = functools.partial(
unittest.skipIf, not (HAS_GPU or HAS_MPS), "requires gpu"
)
requires_triton = functools.partial(unittest.skipIf, not HAS_TRITON, "requires triton")
requires_helion = functools.partial(unittest.skipIf, not HAS_HELION, "requires helion")
def requires_cuda_with_enough_memory(min_mem_required):
def inner(fn):
if not torch.cuda.is_available() or torch.cuda.get_device_properties().total_memory < min_mem_required:
return unittest.skip(f"Only if the CUDA device has at least {min_mem_required / 1e9:.3f}GB memory to be safe")(fn)
if (
not torch.cuda.is_available()
or torch.cuda.get_device_properties().total_memory < min_mem_required
):
return unittest.skip(
f"Only if the CUDA device has at least {min_mem_required / 1e9:.3f}GB memory to be safe"
)(fn)
else:
return fn
return inner
skipCUDAIf = functools.partial(skipDeviceIf, device="cuda")
skipXPUIf = functools.partial(skipDeviceIf, device="xpu")
skipCPUIf = functools.partial(skipDeviceIf, device="cpu")
IS_A100 = LazyVal(
lambda: HAS_CUDA_AND_TRITON
and get_gpu_shared_memory() == 166912
)
IS_A100 = LazyVal(lambda: HAS_CUDA_AND_TRITON and get_gpu_shared_memory() == 166912)
IS_H100 = LazyVal(
lambda: HAS_CUDA_AND_TRITON
and get_gpu_shared_memory() == 232448
)
IS_H100 = LazyVal(lambda: HAS_CUDA_AND_TRITON and get_gpu_shared_memory() == 232448)
IS_BIG_GPU = LazyVal(lambda: HAS_CUDA_AND_TRITON and is_big_gpu())
def dummy_graph() -> GraphLowering:
"""
Create a graph. This is useful for unit testing code which accesses
@ -190,6 +208,7 @@ def dummy_graph() -> GraphLowering:
return graph
def maybe_skip_size_asserts(op):
"""
For certain ops, their meta and eager implementations return different
@ -226,12 +245,19 @@ def maybe_skip_size_asserts(op):
else:
return contextlib.nullcontext()
def get_func_call() -> str:
return "void inductor_entry_impl(" if torch._inductor.config.cpp_wrapper else "def call("
return (
"void inductor_entry_impl("
if torch._inductor.config.cpp_wrapper
else "def call("
)
def get_kernel_launch() -> str:
return "call_triton_" if torch._inductor.config.cpp_wrapper else ".run("
def clone_preserve_strides_offset(x, device=None):
if not isinstance(x, torch.Tensor):
return x
@ -245,6 +271,7 @@ def clone_preserve_strides_offset(x, device=None):
out = torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset())
return out
# define the e4m3/e5m2 constants
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
@ -256,6 +283,7 @@ EPS: float = 1e-12
Tensor = torch.Tensor
def _to_fp8_saturated(x: Tensor, float8_dtype: torch.dtype) -> Tensor:
# The default behavior in PyTorch for casting to `float8_e4m3fn`
# and `e5m2` is to not saturate. In this context, we should saturate.
@ -275,6 +303,7 @@ def _to_fp8_saturated(x: Tensor, float8_dtype: torch.dtype) -> Tensor:
raise TypeError(f"Unsupported float8_dtype: {float8_dtype}")
return x.to(float8_dtype)
@torch.no_grad()
def _amax_to_scale(
amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
@ -293,6 +322,7 @@ def _amax_to_scale(
res = torch.clamp(res, max=FP16_MAX_POS)
return res
def _quantize_tensorwise(x: Tensor, float8_dtype: torch.dtype):
amax = torch.max(torch.abs(x))
scale = _amax_to_scale(amax, float8_dtype, x.dtype)
@ -300,6 +330,7 @@ def _quantize_tensorwise(x: Tensor, float8_dtype: torch.dtype):
inverse_scale = scale.reciprocal()
return x_fp8, inverse_scale
def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype):
amax = torch.max(torch.abs(x), dim=1, keepdim=True).values
scale = _amax_to_scale(amax, float8_dtype, x.dtype)
@ -307,6 +338,28 @@ def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype):
inverse_scale = scale.reciprocal()
return x_fp8, inverse_scale
def _quantize_blockwise(
x: Tensor, float8_dtype: torch.dtype, block_outer: int, block_inner: int
):
min_outer = min(block_outer, x.shape[0])
min_inner = min(block_inner, x.shape[1])
x = x.unflatten(1, (-1, min_inner)).unflatten(0, (-1, min_outer))
amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
scale = _amax_to_scale(amax, float8_dtype, x.dtype)
x = x.flatten(2, 3).flatten(0, 1)
scale = scale.flatten(2, 3).flatten(0, 1)
scale_expanded = scale.repeat_interleave(min_outer, dim=0).repeat_interleave(
min_inner, dim=1
)
x_fp8 = _to_fp8_saturated(
x / scale_expanded, # Ensures that scaling doesn't cause inf/nan values
float8_dtype,
)
inverse_scale = scale.reciprocal()
return x_fp8, inverse_scale
class MockGraphHandler(GraphLowering):
"""Minimal mock graph handler for testing virtualized context."""
@ -325,12 +378,13 @@ class MockGraphHandler(GraphLowering):
"""Return default dtype for any buffer (for testing)."""
return torch.float32
@contextlib.contextmanager
def patch_inductor_backend(
device: str,
python_wrapper_codegen: PythonWrapperCodegen = None,
custom_pass: CustomGraphModulePass = None,
custom_backend_config: ConfigModule = None
custom_backend_config: ConfigModule = None,
):
"""
Patch the inductor backend for a specific device.
@ -351,11 +405,19 @@ def patch_inductor_backend(
register_backend_for_device(
device,
original_scheduling,
python_wrapper_codegen if python_wrapper_codegen is not None else original_python_wrapper,
(
python_wrapper_codegen
if python_wrapper_codegen is not None
else original_python_wrapper
),
original_cpp_wrapper,
original_fx_wrapper,
custom_pass if custom_pass is not None else original_custom_pass,
custom_backend_config if custom_backend_config is not None else original_custom_backend_config
(
custom_backend_config
if custom_backend_config is not None
else original_custom_backend_config
),
)
yield
finally:
@ -367,5 +429,5 @@ def patch_inductor_backend(
original_cpp_wrapper,
original_fx_wrapper,
original_custom_pass,
original_custom_backend_config
original_custom_backend_config,
)