[Cutlass] Support float8_e4m3fn GEMM (#153890)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153890
Approved by: https://github.com/drisspg, https://github.com/eellison
commit 423fc671e9
parent c1b7dbc52a
Author: Michael Lazos
Date: 2025-05-21 22:13:51 -07:00
Committed by: PyTorch MergeBot
8 changed files with 326 additions and 103 deletions

View File

@@ -41,14 +41,23 @@ from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import fresh_inductor_cache
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import SM80OrLater, SM90OrLater
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_FP8,
SM80OrLater,
SM90OrLater,
)
from torch.testing._internal.common_utils import (
IN_RE_WORKER,
instantiate_parametrized_tests,
IS_FBCODE,
parametrize,
)
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
from torch.testing._internal.inductor_utils import (
_quantize_rowwise,
_quantize_tensorwise,
HAS_CPU,
HAS_CUDA,
)
torch.set_float32_matmul_precision("high")
@@ -127,6 +136,17 @@ use_evt_config = config.patch(
}
)
fp8_config = config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTLASS",
"cuda.cutlass_max_profiling_configs": 1,
"autotune_fallback_to_aten": False,
"benchmark_epilogue_fusion": False, # EVT doesn't support benchmark fusion yet
"cuda.cutlass_tma_only": True,
}
)
@instantiate_parametrized_tests
class TestCutlassBackend(TestCase):
@@ -1643,6 +1663,155 @@ class TestCutlassBackend(TestCase):
self.assertGreater(count, 1000, "Too few ops generated")
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+")
@unittest.skipIf(not SM90OrLater, "need sm_90")
@fp8_config
@parametrize("float8_dtype", (torch.float8_e4m3fn,))
@parametrize(
"shape",
(
(
16,
16,
32,
),
),
)
@parametrize("has_bias", (False,))
@parametrize("use_fast_accum", (False,))
def test_fp8_rowwise_scaling(
self,
float8_dtype: torch.dtype,
shape: tuple[int, int, int],
has_bias: bool,
use_fast_accum: bool,
):
# Only bf16 output type is supported for row-wise scaling, not fp32
output_dtype: torch.dtype = torch.bfloat16
device = "cuda"
M, K, N = shape # Matmul Y = X [M, K] x W [N, K]
x = torch.randn(M, K, dtype=output_dtype, device=device)
w = torch.randn(N, K, dtype=output_dtype, device=device)
bias = None
if has_bias:
bias = torch.randn(N, device=device, dtype=torch.bfloat16)
# quantize weight (prior to inference)
w_fp8, w_inverse_scale = _quantize_rowwise(w, float8_dtype)
w_t_fp8 = w_fp8.t()
w_inverse_scale = w_inverse_scale.t() # scale_b should be (1, N)
# quantize input x
x_fp8, x_inverse_scale = _quantize_rowwise(x, float8_dtype)
def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
y = torch._scaled_mm(
x_fp8,
w_t_fp8,
x_inverse_scale,
w_inverse_scale,
bias,
out_dtype=output_dtype,
use_fast_accum=use_fast_accum,
)
return y
y_eager = linear(
x_fp8,
x_inverse_scale,
w_t_fp8,
w_inverse_scale,
bias,
)
linear_compiled = torch.compile(linear, backend="inductor")
y_compiled = linear_compiled(
x_fp8,
x_inverse_scale,
w_t_fp8,
w_inverse_scale,
bias,
)
self.assertEqual(y_eager.dtype, output_dtype)
self.assertEqual(y_compiled.dtype, output_dtype)
torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)
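For reference, a minimal sketch of the scale shapes the row-wise convention above expects (illustrative only, not part of the test; it uses the _quantize_rowwise helper this PR moves into inductor_utils and needs no GPU):
import torch
from torch.testing._internal.inductor_utils import _quantize_rowwise
M, K, N = 16, 16, 32
x = torch.randn(M, K, dtype=torch.bfloat16)  # activations, quantized per row of X
w = torch.randn(N, K, dtype=torch.bfloat16)  # weights, quantized per row of W
x_fp8, x_inverse_scale = _quantize_rowwise(x, torch.float8_e4m3fn)
w_fp8, w_inverse_scale = _quantize_rowwise(w, torch.float8_e4m3fn)
assert x_inverse_scale.shape == (M, 1)      # scale_a: one scale per row of X
assert w_inverse_scale.t().shape == (1, N)  # scale_b: one scale per column of Y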
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+")
@unittest.skipIf(not SM90OrLater, "need sm_90")
@fp8_config
@parametrize("float8_dtype", (torch.float8_e4m3fn,))
@parametrize(
"shape",
(
(
16,
16,
32,
),
),
)
@parametrize("has_bias", (False,))
@parametrize("use_fast_accum", (False,))
def test_fp8_tensorwise_scaling(
self,
float8_dtype: torch.dtype,
shape: tuple[int, int, int],
has_bias: bool,
use_fast_accum: bool,
):
device = "cuda"
M, K, N = shape # Matmul Y = X [M, K] x W [N, K]
input_dtype = torch.bfloat16
output_dtype = torch.bfloat16
# input and output dtypes of _scaled_mm do not need to be the same, but
# typically in a model they are
x = torch.randn(M, K, dtype=input_dtype, device=device)
w = torch.randn(N, K, dtype=input_dtype, device=device)
bias = None
if has_bias:
bias = torch.randn(N, device=device, dtype=torch.bfloat16)
# quantize weight (prior to inference)
w_fp8, w_inverse_scale = _quantize_tensorwise(w, float8_dtype)
w_t_fp8 = w_fp8.t()
# quantize input x
x_fp8, x_inverse_scale = _quantize_tensorwise(x, float8_dtype)
def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
y = torch._scaled_mm(
x_fp8,
w_t_fp8,
x_inverse_scale,
w_inverse_scale,
bias,
out_dtype=output_dtype,
use_fast_accum=use_fast_accum,
)
return y
y_eager = linear(
x_fp8,
x_inverse_scale,
w_t_fp8,
w_inverse_scale,
bias,
)
linear_compiled = torch.compile(linear, backend="inductor", mode="max-autotune")
y_compiled = linear_compiled(
x_fp8,
x_inverse_scale,
w_t_fp8,
w_inverse_scale,
bias,
)
self.assertEqual(y_eager.dtype, output_dtype)
self.assertEqual(y_compiled.dtype, output_dtype)
# depending on the kernel config (BLOCK_M size, etc.) selected during Inductor
# autotuning for the compiled case, the results can differ because of the order
# in which blocks of results are accumulated (floating-point addition is not
# associative), so we use a small absolute tolerance in these tests
torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)
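The tolerance above hinges on floating-point addition not being associative, so different accumulation orders (tilings) legitimately produce slightly different results; a one-line illustration:
print((0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3))  # False under IEEE-754 double precision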
if __name__ == "__main__":
from torch._inductor.utils import is_big_gpu

View File

@@ -17,7 +17,13 @@ from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
)
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
from torch.testing._internal.inductor_utils import (
_quantize_rowwise,
_quantize_tensorwise,
_to_fp8_saturated,
HAS_CPU,
HAS_CUDA,
)
from torch.utils._triton import has_triton_tma_device
@@ -26,70 +32,6 @@ torch.set_float32_matmul_precision("high")
f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
# define the e4m3/e5m2 constants
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
E4M3FNUZ_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
E5M2FNUZ_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
FP16_MAX_POS: float = torch.finfo(torch.float16).max
EPS: float = 1e-12
def _to_fp8_saturated(x: Tensor, float8_dtype: torch.dtype) -> Tensor:
# The default behavior in PyTorch for casting to `float8_e4m3fn`
# and `e5m2` is to not saturate. In this context, we should saturate.
# A common case where we want to saturate is when the history of a
# tensor has a maximum value of `amax1`, and the current amax value
# is `amax2`, where `amax1 < amax2`. This is common when using delayed
# scaling.
if float8_dtype == torch.float8_e4m3fn:
x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
elif float8_dtype == torch.float8_e5m2:
x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
elif float8_dtype == torch.float8_e4m3fnuz:
x = x.clamp(min=-1 * E4M3FNUZ_MAX_POS, max=E4M3FNUZ_MAX_POS)
elif float8_dtype == torch.float8_e5m2fnuz:
x = x.clamp(min=-1 * E5M2FNUZ_MAX_POS, max=E5M2FNUZ_MAX_POS)
else:
raise TypeError(f"Unsupported float8_dtype: {float8_dtype}")
return x.to(float8_dtype)
@torch.no_grad()
def _amax_to_scale(
amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
) -> torch.Tensor:
# Compute the scale in fp32 for accuracy
amax = amax.float()
if float8_dtype == torch.float8_e4m3fn:
res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
else: # e5m2
res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
# Ensure that the scale is representable in float16,
# this helps when amax is small. We are assuming that we don't need
# to care about this for float32/bfloat16.
if orig_dtype is torch.float16:
res = torch.clamp(res, max=FP16_MAX_POS)
return res
def _quantize_tensorwise(x: Tensor, float8_dtype: torch.dtype):
amax = torch.max(torch.abs(x))
scale = _amax_to_scale(amax, float8_dtype, x.dtype)
x_fp8 = _to_fp8_saturated(x * scale, float8_dtype)
inverse_scale = scale.reciprocal()
return x_fp8, inverse_scale
def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype):
amax = torch.max(torch.abs(x), dim=1, keepdim=True).values
scale = _amax_to_scale(amax, float8_dtype, x.dtype)
x_fp8 = _to_fp8_saturated(x * scale, float8_dtype)
inverse_scale = scale.reciprocal()
return x_fp8, inverse_scale
def _fix_fp8_dtype_for_rocm(
dtype: Union[torch.dtype, list[torch.dtype], tuple[torch.dtype]], device

View File

@@ -284,6 +284,7 @@ class CUTLASSTemplate(CUDATemplate):
torch.uint8: "uint8_t",
torch.bool: "bool",
torch.bfloat16: "cutlass::bfloat16_t",
torch.float8_e4m3fn: "cutlass::float_e4m3_t",
}
_DTYPE_TO_CUTLASS_SPARSE_META = {

View File

@@ -20,6 +20,20 @@ from ...virtualized import V
_ACCUMULATOR_ARG_NAME = "accum"
def scaled_mm_evt(
scale_A_name: str, scale_B_name: str, output_name: str
) -> tuple[list[str], dict[str, Any], str]:
evt_read_names = [scale_A_name, scale_B_name]
var_name_to_buffer_name = {n: n for n in [scale_A_name, scale_B_name]}
var_name_to_buffer_name["D"] = output_name
var_name_to_buffer_name[_ACCUMULATOR_ARG_NAME] = output_name
evt_py_code = f"def fn(accum, {scale_A_name}, {scale_B_name}):{linesep}\
D = accum * {scale_A_name} * {scale_B_name}{linesep}\
return D{linesep}"
return evt_read_names, var_name_to_buffer_name, evt_py_code
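For illustration (hypothetical buffer names): for scale buffers arg2_1 and arg3_1 and output buffer buf0, scaled_mm_evt returns roughly:
# evt_read_names          -> ["arg2_1", "arg3_1"]
# var_name_to_buffer_name -> {"arg2_1": "arg2_1", "arg3_1": "arg3_1",
#                             "D": "buf0", "accum": "buf0"}
# evt_py_code             -> a program of the form:
#     def fn(accum, arg2_1, arg3_1):
#         D = accum * arg2_1 * arg3_1
#         return D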
class CutlassEVTOpsMixIn:
@staticmethod
def _infix_bin_op(op: str, a: str, b: str) -> str:

View File

@@ -314,6 +314,7 @@ DTYPE_TO_CUTLASS_TYPE = {
**DTYPE_TO_CPP,
torch.float16: "__half",
torch.bfloat16: "__nv_bfloat16",
torch.float8_e4m3fn: "cutlass::float_e4m3_t",
}
@@ -359,6 +360,8 @@ def dtype_match(
return cutlass_dtype == cutlass_library.library.DataType.u8
elif torch_dtype == torch.int32:
return cutlass_dtype == cutlass_library.library.DataType.s32
elif torch_dtype == torch.float8_e4m3fn:
return cutlass_dtype == cutlass_library.library.DataType.e4m3
else:
return False
@@ -389,7 +392,7 @@ def get_accumulator_dtype(
]:
torch_dtype = dtype0
if torch_dtype in (torch.float16, torch.bfloat16, torch.float):
if torch_dtype in (torch.float16, torch.bfloat16, torch.float, torch.float8_e4m3fn):
return torch.float
if torch_dtype == torch.int8:
return torch.int32
@@ -407,7 +410,7 @@ def get_alignments(torch_dtype: torch.dtype) -> list[int]:
return [8, 4, 2, 1]
elif torch_dtype == torch.float:
return [4, 2, 1]
elif torch_dtype in (torch.uint8, torch.int8):
elif torch_dtype in (torch.uint8, torch.int8, torch.float8_e4m3fn):
return [16, 8, 4, 2]
elif torch_dtype == torch.int32:
return [4, 2, 1]
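A quick sanity-check sketch (illustrative, not part of the change): CUTLASS alignments are counted in elements, and the largest entry of each list corresponds to a 16-byte (128-bit) vectorized access, which is why the one-byte float8_e4m3fn joins the uint8/int8 list:
import torch
for dtype, alignments in {
    torch.float8_e4m3fn: [16, 8, 4, 2],  # 1-byte elements
    torch.bfloat16: [8, 4, 2, 1],        # 2-byte elements
    torch.float32: [4, 2, 1],            # 4-byte elements
}.items():
    assert max(alignments) * dtype.itemsize == 16  # 16 bytes = 128 bits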

View File

@@ -31,7 +31,7 @@ from . import cutlass_utils
from .cuda_kernel import CUDATemplateKernel
from .cuda_template import CUTLASSTemplate
from .cutlass_presets import gen_cutlass_presets
from .cutlass_python_evt import CutlassEVTCodegen
from .cutlass_python_evt import CutlassEVTCodegen, scaled_mm_evt
from .cutlass_utils import torch_dtype_to_cutlass_type
@@ -435,7 +435,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
)
self.alpha = alpha
self.beta = beta
assert len(input_nodes) == 2 or len(input_nodes) == 3
assert len(input_nodes) == 2 or len(input_nodes) == 3 or len(input_nodes) == 4
assert self._are_inputs_layout_compatible(
[node.get_layout() for node in input_nodes]
)
@@ -1064,28 +1064,54 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
op = self.swap_XW(op)
should_swap_xw = True
if epilogue_nodes:
(
evt_read_names,
evt_write_names,
var_name_to_buffer_name,
evt_py_code,
) = CutlassEVTCodegen.ir_to_evt_python_code(
Y.get_name(), epilogue_nodes, V.kernel.removed_buffers
)
read_names = OrderedSet(evt_read_names) - OrderedSet(evt_write_names)
write_names = OrderedSet(evt_write_names)
assert write_names, "There should be at least one write"
D_output_name = var_name_to_buffer_name["D"]
name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs
D_output_buffer = name_to_buffer[D_output_name]
Y = D_output_buffer # type: ignore[assignment]
# Interestingly, I don't think the rest of the layout matters here since we
# use the properties of the Y buffer to fill in D's properties in the epilogue
# args. This is needed though because it defines types expected in the epilogue args.
op.D.element = cutlass_utils.torch_dtype_to_cutlass_type(
D_output_buffer.get_layout().dtype
)
is_scaled_mm = len(self.input_nodes) == 4
if epilogue_nodes or is_scaled_mm:
if epilogue_nodes:
(
evt_read_names,
evt_write_names,
var_name_to_buffer_name,
evt_py_code,
) = CutlassEVTCodegen.ir_to_evt_python_code(
Y.get_name(), epilogue_nodes, V.kernel.removed_buffers
)
D_output_name = var_name_to_buffer_name["D"]
name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs
D_output_buffer = name_to_buffer[D_output_name]
D_dtype = D_output_buffer.get_dtype()
Y = D_output_buffer # type: ignore[assignment]
# Interestingly, I don't think the rest of the layout matters here since we
# use the properties of the Y buffer to fill in D's properties in the epilogue
# args. This is needed though because it defines types expected in the epilogue args.
op.D.element = cutlass_utils.torch_dtype_to_cutlass_type(
D_output_buffer.get_layout().dtype
)
read_names = OrderedSet(evt_read_names) - OrderedSet(evt_write_names)
write_names = OrderedSet(evt_write_names)
assert write_names, "There should be at least one write"
input_names = list(read_names)
output_names = list(write_names)
epilogue_inputs = [name_to_buffer[name] for name in input_names]
epilogue_outputs = [name_to_buffer[name] for name in output_names]
else: # Scaled MM, we read the two scale matrices and write a single output
(
evt_read_names,
var_name_to_buffer_name,
evt_py_code,
) = scaled_mm_evt(
self.input_nodes[2].get_name(),
self.input_nodes[3].get_name(),
Y.get_name(),
)
input_names = list(evt_read_names)
output_names = [] # We only need Y
D_dtype = Y.get_layout().dtype
epilogue_inputs = [self.input_nodes[2], self.input_nodes[3]]
epilogue_outputs = []
acc_dtype = cutlass_utils.get_accumulator_dtype(
[X.get_dtype(), W.get_dtype()]
)
@@ -1095,24 +1121,20 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
op,
evt_py_code,
var_name_to_buffer_name,
D_output_buffer.get_dtype(),
D_dtype,
acc_dtype,
)
input_names = list(read_names)
output_names = list(write_names)
epilogue_inputs = [name_to_buffer[name] for name in input_names]
epilogue_outputs = [name_to_buffer[name] for name in output_names]
inputs = [
X,
W,
Bias,
Y,
*epilogue_inputs, # type: ignore[list-item]
Y,
*extra_inputs,
]
names_str = ",".join(
["X", "W", "Bias", "Y", *input_names, *output_names, *extra_names]
["X", "W", "Bias", *input_names, "Y", *output_names, *extra_names]
)
else:
evt_name = None
@@ -1286,7 +1308,7 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
Returns:
bool: True if layouts are GEMM compatible, otherwise False.
"""
assert len(layouts) == 2 or len(layouts) == 3
assert len(layouts) == 2 or len(layouts) == 3 or len(layouts) == 4
# Check if A and B are compatible
A_layout, B_layout = layouts[:2]
if len(A_layout.size) < 1:
@@ -1354,6 +1376,8 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
from .cutlass_lib_extensions.evt_extensions import create_example_tensors, trace
name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs
# handle the fake output buffer during lowering
name_to_buffer[self.output_node.get_name()] = self.output_node # type: ignore[assignment]
acc_dtype = torch_dtype_to_cutlass_type(accumulator_dtype)
output_dtype = torch_dtype_to_cutlass_type(output_dtype)
@@ -1395,7 +1419,7 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
) -> bool:
import cutlass_library.library as cutlass_lib
has_bias = len(self.input_nodes) >= 3 and self.input_nodes[2] is not None
has_bias = len(self.input_nodes) == 3 and self.input_nodes[2] is not None
if has_bias:
Bias = self.input_nodes[2]
# bias dtype
@@ -1464,7 +1488,7 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
self,
op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821
) -> tuple[Optional[Buffer], list[Optional[Buffer]], list[str]]:
Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2]
Bias = None if len(self.input_nodes) in (2, 4) else self.input_nodes[2]
inputs: list[Optional[Buffer]] = []
names: list[str] = []
return (Bias, inputs, names)

View File

@@ -1187,6 +1187,14 @@ def tuned_scaled_mm(
epilogue_fn=scale_mm_epilogue(),
)
if is_nonzero and use_cutlass_template(layout, m, n, k):
if use_fast_accum:
log.warning(
"use_fast_accum=True is not supported by cutlass template, skipping cutlass choices"
)
else:
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, input_nodes) # type: ignore[arg-type]
if is_nonzero and use_ck_gemm_template(layout, m, n, k):
CKGemmTemplate.add_ck_gemm_choices(choices, layout, input_nodes)

View File

@@ -228,3 +228,65 @@ def clone_preserve_strides_offset(x, device=None):
buffer = buffer.to(device, copy=True)
out = torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset())
return out
# define the e4m3/e5m2 constants
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
E4M3FNUZ_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
E5M2FNUZ_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
FP16_MAX_POS: float = torch.finfo(torch.float16).max
EPS: float = 1e-12
Tensor = torch.Tensor
def _to_fp8_saturated(x: Tensor, float8_dtype: torch.dtype) -> Tensor:
# The default behavior in PyTorch for casting to `float8_e4m3fn`
# and `e5m2` is to not saturate. In this context, we should saturate.
# A common case where we want to saturate is when the history of a
# tensor has a maximum value of `amax1`, and the current amax value
# is `amax2`, where `amax1 < amax2`. This is common when using delayed
# scaling.
if float8_dtype == torch.float8_e4m3fn:
x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
elif float8_dtype == torch.float8_e5m2:
x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
elif float8_dtype == torch.float8_e4m3fnuz:
x = x.clamp(min=-1 * E4M3FNUZ_MAX_POS, max=E4M3FNUZ_MAX_POS)
elif float8_dtype == torch.float8_e5m2fnuz:
x = x.clamp(min=-1 * E5M2FNUZ_MAX_POS, max=E5M2FNUZ_MAX_POS)
else:
raise TypeError(f"Unsupported float8_dtype: {float8_dtype}")
return x.to(float8_dtype)
@torch.no_grad()
def _amax_to_scale(
amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
) -> torch.Tensor:
# Compute the scale in fp32 for accuracy
amax = amax.float()
if float8_dtype == torch.float8_e4m3fn:
res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
else: # e5m2
res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
# Ensure that the scale is representable in float16,
# this helps when amax is small. We are assuming that we don't need
# to care about this for float32/bfloat16.
if orig_dtype is torch.float16:
res = torch.clamp(res, max=FP16_MAX_POS)
return res
def _quantize_tensorwise(x: Tensor, float8_dtype: torch.dtype):
amax = torch.max(torch.abs(x))
scale = _amax_to_scale(amax, float8_dtype, x.dtype)
x_fp8 = _to_fp8_saturated(x * scale, float8_dtype)
inverse_scale = scale.reciprocal()
return x_fp8, inverse_scale
def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype):
amax = torch.max(torch.abs(x), dim=1, keepdim=True).values
scale = _amax_to_scale(amax, float8_dtype, x.dtype)
x_fp8 = _to_fp8_saturated(x * scale, float8_dtype)
inverse_scale = scale.reciprocal()
return x_fp8, inverse_scale
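A minimal round-trip sketch for the helpers above (illustrative only, CPU is fine): quantize tensor-wise, dequantize with the returned inverse scale, and check that the reconstruction error stays within what float8_e4m3fn's ~3-bit mantissa allows:
x = torch.randn(4, 8, dtype=torch.bfloat16)
x_fp8, inverse_scale = _quantize_tensorwise(x, torch.float8_e4m3fn)
x_roundtrip = (x_fp8.to(torch.float32) * inverse_scale).to(torch.bfloat16)
# generous tolerances: e4m3 carries roughly two decimal digits of precision
torch.testing.assert_close(x_roundtrip, x, rtol=0.15, atol=0.05)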