[Cutlass] Support float8_e4m3fn GEMM (#153890)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153890
Approved by: https://github.com/drisspg, https://github.com/eellison

Committed by: PyTorch MergeBot
Parent: c1b7dbc52a
Commit: 423fc671e9
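In short, this change teaches Inductor's CUTLASS backend to autotune torch._scaled_mm calls whose operands are float8_e4m3fn, wiring the two dequantization scales into a CUTLASS epilogue visitor tree (EVT). A minimal usage sketch, mirroring the new tests further below; _quantize_tensorwise and the config keys are the ones this PR adds or exercises, while the shapes and max-autotune settings here are illustrative only:

import torch
from torch._inductor import config
from torch.testing._internal.inductor_utils import _quantize_tensorwise

M, K, N = 16, 16, 32
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")

# Quantize both operands to float8_e4m3fn with tensor-wise scales.
x_fp8, x_inverse_scale = _quantize_tensorwise(x, torch.float8_e4m3fn)
w_fp8, w_inverse_scale = _quantize_tensorwise(w, torch.float8_e4m3fn)


def linear(x_fp8, x_scale, w_t_fp8, w_scale):
    return torch._scaled_mm(
        x_fp8, w_t_fp8, x_scale, w_scale, out_dtype=torch.bfloat16
    )


# Restrict autotuning to the CUTLASS backend so the new FP8 templates are exercised.
with config.patch(
    {
        "max_autotune": True,
        "max_autotune_gemm_backends": "CUTLASS",
        "cuda.cutlass_max_profiling_configs": 1,
    }
):
    compiled = torch.compile(linear, backend="inductor", mode="max-autotune")
    y = compiled(x_fp8, x_inverse_scale, w_fp8.t(), w_inverse_scale)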
@@ -41,14 +41,23 @@ from torch._inductor.test_case import run_tests, TestCase
 from torch._inductor.utils import fresh_inductor_cache
 from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured
 from torch.testing import FileCheck
-from torch.testing._internal.common_cuda import SM80OrLater, SM90OrLater
+from torch.testing._internal.common_cuda import (
+    PLATFORM_SUPPORTS_FP8,
+    SM80OrLater,
+    SM90OrLater,
+)
 from torch.testing._internal.common_utils import (
     IN_RE_WORKER,
     instantiate_parametrized_tests,
     IS_FBCODE,
     parametrize,
 )
-from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
+from torch.testing._internal.inductor_utils import (
+    _quantize_rowwise,
+    _quantize_tensorwise,
+    HAS_CPU,
+    HAS_CUDA,
+)


 torch.set_float32_matmul_precision("high")
@@ -127,6 +136,17 @@ use_evt_config = config.patch(
     }
 )

+fp8_config = config.patch(
+    {
+        "max_autotune": True,
+        "max_autotune_gemm_backends": "CUTLASS",
+        "cuda.cutlass_max_profiling_configs": 1,
+        "autotune_fallback_to_aten": False,
+        "benchmark_epilogue_fusion": False,  # EVT doesn't support benchmark fusion yet
+        "cuda.cutlass_tma_only": True,
+    }
+)
+

 @instantiate_parametrized_tests
 class TestCutlassBackend(TestCase):
@@ -1643,6 +1663,155 @@ class TestCutlassBackend(TestCase):

         self.assertGreater(count, 1000, "Too few ops generated")

+    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+")
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @fp8_config
+    @parametrize("float8_dtype", (torch.float8_e4m3fn,))
+    @parametrize(
+        "shape",
+        (
+            (
+                16,
+                16,
+                32,
+            ),
+        ),
+    )
+    @parametrize("has_bias", (False,))
+    @parametrize("use_fast_accum", (False,))
+    def test_fp8_rowwise_scaling(
+        self,
+        float8_dtype: torch.dtype,
+        shape: tuple[int, int, int],
+        has_bias: bool,
+        use_fast_accum: bool,
+    ):
+        # Only bf16 output type is supported for row-wise scaling, not fp32
+        output_dtype: torch.dtype = torch.bfloat16
+        device = "cuda"
+        M, K, N = shape  # Matmul Y = X [M, K] x W [N, K]
+        x = torch.randn(M, K, dtype=output_dtype, device=device)
+        w = torch.randn(N, K, dtype=output_dtype, device=device)
+        bias = None
+        if has_bias:
+            bias = torch.randn(N, device=device, dtype=torch.bfloat16)
+
+        # quantize weight (prior to inference)
+        w_fp8, w_inverse_scale = _quantize_rowwise(w, float8_dtype)
+        w_t_fp8 = w_fp8.t()
+        w_inverse_scale = w_inverse_scale.t()  # scale_b should be (1, N)
+
+        # quantize input x
+        x_fp8, x_inverse_scale = _quantize_rowwise(x, float8_dtype)
+
+        def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
+            y = torch._scaled_mm(
+                x_fp8,
+                w_t_fp8,
+                x_inverse_scale,
+                w_inverse_scale,
+                bias,
+                out_dtype=output_dtype,
+                use_fast_accum=use_fast_accum,
+            )
+            return y
+
+        y_eager = linear(
+            x_fp8,
+            x_inverse_scale,
+            w_t_fp8,
+            w_inverse_scale,
+            bias,
+        )
+        linear_compiled = torch.compile(linear, backend="inductor")
+        y_compiled = linear_compiled(
+            x_fp8,
+            x_inverse_scale,
+            w_t_fp8,
+            w_inverse_scale,
+            bias,
+        )
+        self.assertEqual(y_eager.dtype, output_dtype)
+        self.assertEqual(y_compiled.dtype, output_dtype)
+        torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)
+
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+")
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @fp8_config
+    @parametrize("float8_dtype", (torch.float8_e4m3fn,))
+    @parametrize(
+        "shape",
+        (
+            (
+                16,
+                16,
+                32,
+            ),
+        ),
+    )
+    @parametrize("has_bias", (False,))
+    @parametrize("use_fast_accum", (False,))
+    def test_fp8_tensorwise_scaling(
+        self,
+        float8_dtype: torch.dtype,
+        shape: tuple[int, int, int],
+        has_bias: bool,
+        use_fast_accum: bool,
+    ):
+        device = "cuda"
+        M, K, N = shape  # Matmul Y = X [M, K] x W [N, K]
+        input_dtype = torch.bfloat16
+        output_dtype = torch.bfloat16
+        # input and output dtypes of _scaled_mm do not need to be the same, but
+        # typically in a model they are
+        x = torch.randn(M, K, dtype=input_dtype, device=device)
+        w = torch.randn(N, K, dtype=input_dtype, device=device)
+        bias = None
+        if has_bias:
+            bias = torch.randn(N, device=device, dtype=torch.bfloat16)
+
+        # quantize weight (prior to inference)
+        w_fp8, w_inverse_scale = _quantize_tensorwise(w, float8_dtype)
+        w_t_fp8 = w_fp8.t()
+
+        # quantize input x
+        x_fp8, x_inverse_scale = _quantize_tensorwise(x, float8_dtype)
+
+        def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
+            y = torch._scaled_mm(
+                x_fp8,
+                w_t_fp8,
+                x_inverse_scale,
+                w_inverse_scale,
+                bias,
+                out_dtype=output_dtype,
+                use_fast_accum=use_fast_accum,
+            )
+            return y
+
+        y_eager = linear(
+            x_fp8,
+            x_inverse_scale,
+            w_t_fp8,
+            w_inverse_scale,
+            bias,
+        )
+        linear_compiled = torch.compile(linear, backend="inductor", mode="max-autotune")
+        y_compiled = linear_compiled(
+            x_fp8,
+            x_inverse_scale,
+            w_t_fp8,
+            w_inverse_scale,
+            bias,
+        )
+        self.assertEqual(y_eager.dtype, output_dtype)
+        self.assertEqual(y_compiled.dtype, output_dtype)
+        # depending on the kernel config (BLOCK_M size, etc) selected during Inductor
+        # autotuning for the compiled case, the results can be different because of
+        # the way blocks of results are accumulated (float addition not associative), so
+        # setting a small absolute tolerance in these tests
+        torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)


 if __name__ == "__main__":
     from torch._inductor.utils import is_big_gpu
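For reference, the two quantization helpers used above differ only in how the amax is reduced: _quantize_tensorwise produces a single scalar inverse scale, while _quantize_rowwise produces one inverse scale per row, which is why the weight's inverse scale is transposed so that scale_b has shape (1, N). A small shape check (illustrative only; CPU is fine for this):

import torch
from torch.testing._internal.inductor_utils import _quantize_rowwise, _quantize_tensorwise

x = torch.randn(16, 32, dtype=torch.bfloat16)

_, inv_scale_tensorwise = _quantize_tensorwise(x, torch.float8_e4m3fn)
_, inv_scale_rowwise = _quantize_rowwise(x, torch.float8_e4m3fn)

print(inv_scale_tensorwise.shape)  # torch.Size([])     -- one scale for the whole tensor
print(inv_scale_rowwise.shape)     # torch.Size([16, 1]) -- one scale per row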
@@ -17,7 +17,13 @@ from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
 )
-from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
+from torch.testing._internal.inductor_utils import (
+    _quantize_rowwise,
+    _quantize_tensorwise,
+    _to_fp8_saturated,
+    HAS_CPU,
+    HAS_CUDA,
+)
 from torch.utils._triton import has_triton_tma_device


@@ -26,70 +32,6 @@ torch.set_float32_matmul_precision("high")

 f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"

-# define the e4m3/e5m2 constants
-E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
-E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
-E4M3FNUZ_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
-E5M2FNUZ_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
-
-FP16_MAX_POS: float = torch.finfo(torch.float16).max
-EPS: float = 1e-12
-
-
-def _to_fp8_saturated(x: Tensor, float8_dtype: torch.dtype) -> Tensor:
-    # The default behavior in PyTorch for casting to `float8_e4m3fn`
-    # and `e5m2` is to not saturate. In this context, we should saturate.
-    # A common case where we want to saturate is when the history of a
-    # tensor has a maximum value of `amax1`, and the current amax value
-    # is `amax2`, where `amax1 < amax2`. This is common when using delayed
-    # scaling.
-    if float8_dtype == torch.float8_e4m3fn:
-        x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
-    elif float8_dtype == torch.float8_e5m2:
-        x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
-    elif float8_dtype == torch.float8_e4m3fnuz:
-        x = x.clamp(min=-1 * E4M3FNUZ_MAX_POS, max=E4M3FNUZ_MAX_POS)
-    elif float8_dtype == torch.float8_e5m2fnuz:
-        x = x.clamp(min=-1 * E5M2FNUZ_MAX_POS, max=E5M2FNUZ_MAX_POS)
-    else:
-        raise TypeError(f"Unsupported float8_dtype: {float8_dtype}")
-    return x.to(float8_dtype)
-
-
-@torch.no_grad()
-def _amax_to_scale(
-    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
-) -> torch.Tensor:
-    # To make scale dtype to be fp32 for accuracy
-    amax = amax.float()
-    if float8_dtype == torch.float8_e4m3fn:
-        res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
-    else:  # e5m2
-        res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
-
-    # Ensure that the scale is representable in float16,
-    # this helps when amax is small. We are assuming that we don't need
-    # to care about this for float32/bfloat16.
-    if orig_dtype is torch.float16:
-        res = torch.clamp(res, max=FP16_MAX_POS)
-    return res
-
-
-def _quantize_tensorwise(x: Tensor, float8_dtype: torch.dtype):
-    amax = torch.max(torch.abs(x))
-    scale = _amax_to_scale(amax, float8_dtype, x.dtype)
-    x_fp8 = _to_fp8_saturated(x * scale, float8_dtype)
-    inverse_scale = scale.reciprocal()
-    return x_fp8, inverse_scale
-
-
-def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype):
-    amax = torch.max(torch.abs(x), dim=1, keepdim=True).values
-    scale = _amax_to_scale(amax, float8_dtype, x.dtype)
-    x_fp8 = _to_fp8_saturated(x * scale, float8_dtype)
-    inverse_scale = scale.reciprocal()
-    return x_fp8, inverse_scale
-
-
 def _fix_fp8_dtype_for_rocm(
     dtype: Union[torch.dtype, list[torch.dtype], tuple[torch.dtype]], device
@@ -284,6 +284,7 @@ class CUTLASSTemplate(CUDATemplate):
         torch.uint8: "uint8_t",
         torch.bool: "bool",
         torch.bfloat16: "cutlass::bfloat16_t",
+        torch.float8_e4m3fn: "cutlass::float_e4m3_t",
     }

     _DTYPE_TO_CUTLASS_SPARSE_META = {
@@ -20,6 +20,20 @@ from ...virtualized import V
 _ACCUMULATOR_ARG_NAME = "accum"


+def scaled_mm_evt(
+    scale_A_name: str, scale_B_name: str, output_name: str
+) -> tuple[list[str], dict[str, Any], str]:
+    evt_read_names = [scale_A_name, scale_B_name]
+    var_name_to_buffer_name = {n: n for n in [scale_A_name, scale_B_name]}
+    var_name_to_buffer_name["D"] = output_name
+    var_name_to_buffer_name[_ACCUMULATOR_ARG_NAME] = output_name
+    evt_py_code = f"def fn(accum, {scale_A_name}, {scale_B_name}):{linesep}\
+    D = accum * {scale_A_name} * {scale_B_name}{linesep}\
+    return D{linesep}"
+
+    return evt_read_names, var_name_to_buffer_name, evt_py_code
+
+
 class CutlassEVTOpsMixIn:
     @staticmethod
     def _infix_bin_op(op: str, a: str, b: str) -> str:
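For illustration only (the buffer names here are hypothetical), the new scaled_mm_evt helper wires the two scale tensors into a CUTLASS epilogue visitor tree by emitting a tiny Python function for Inductor to trace:

read_names, var_to_buf, code = scaled_mm_evt("arg_scale_a", "arg_scale_b", "buf0")
# read_names -> ["arg_scale_a", "arg_scale_b"]
# var_to_buf -> {"arg_scale_a": "arg_scale_a", "arg_scale_b": "arg_scale_b",
#                "D": "buf0", "accum": "buf0"}
# code is equivalent to:
#     def fn(accum, arg_scale_a, arg_scale_b):
#         D = accum * arg_scale_a * arg_scale_b
#         return D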
@@ -314,6 +314,7 @@ DTYPE_TO_CUTLASS_TYPE = {
     **DTYPE_TO_CPP,
     torch.float16: "__half",
     torch.bfloat16: "__nv_bfloat16",
+    torch.float8_e4m3fn: "cutlass::float_e4m3_t",
 }


@@ -359,6 +360,8 @@ def dtype_match(
         return cutlass_dtype == cutlass_library.library.DataType.u8
     elif torch_dtype == torch.int32:
         return cutlass_dtype == cutlass_library.library.DataType.s32
+    elif torch_dtype == torch.float8_e4m3fn:
+        return cutlass_dtype == cutlass_library.library.DataType.e4m3
     else:
         return False

@@ -389,7 +392,7 @@ def get_accumulator_dtype(
     ]:
         torch_dtype = dtype0

-    if torch_dtype in (torch.float16, torch.bfloat16, torch.float):
+    if torch_dtype in (torch.float16, torch.bfloat16, torch.float, torch.float8_e4m3fn):
         return torch.float
     if torch_dtype == torch.int8:
         return torch.int32
@@ -407,7 +410,7 @@ def get_alignments(torch_dtype: torch.dtype) -> list[int]:
         return [8, 4, 2, 1]
     elif torch_dtype == torch.float:
         return [4, 2, 1]
-    elif torch_dtype in (torch.uint8, torch.int8):
+    elif torch_dtype in (torch.uint8, torch.int8, torch.float8_e4m3fn):
         return [16, 8, 4, 2]
     elif torch_dtype == torch.int32:
         return [4, 2, 1]
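Taken together, the cutlass_utils changes above mean a float8_e4m3fn GEMM maps to cutlass::float_e4m3_t, accumulates in float32, and may use vectorized alignments of up to 16 elements. A quick check, assuming the usual torch._inductor.codegen.cuda.cutlass_utils module path for these helpers:

import torch
from torch._inductor.codegen.cuda import cutlass_utils

print(cutlass_utils.get_accumulator_dtype([torch.float8_e4m3fn, torch.float8_e4m3fn]))
# torch.float32
print(cutlass_utils.get_alignments(torch.float8_e4m3fn))
# [16, 8, 4, 2]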
@@ -31,7 +31,7 @@ from . import cutlass_utils
 from .cuda_kernel import CUDATemplateKernel
 from .cuda_template import CUTLASSTemplate
 from .cutlass_presets import gen_cutlass_presets
-from .cutlass_python_evt import CutlassEVTCodegen
+from .cutlass_python_evt import CutlassEVTCodegen, scaled_mm_evt
 from .cutlass_utils import torch_dtype_to_cutlass_type


@@ -435,7 +435,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
         )
         self.alpha = alpha
         self.beta = beta
-        assert len(input_nodes) == 2 or len(input_nodes) == 3
+        assert len(input_nodes) == 2 or len(input_nodes) == 3 or len(input_nodes) == 4
         assert self._are_inputs_layout_compatible(
             [node.get_layout() for node in input_nodes]
         )
@@ -1064,28 +1064,54 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
                 op = self.swap_XW(op)
                 should_swap_xw = True

-        if epilogue_nodes:
-            (
-                evt_read_names,
-                evt_write_names,
-                var_name_to_buffer_name,
-                evt_py_code,
-            ) = CutlassEVTCodegen.ir_to_evt_python_code(
-                Y.get_name(), epilogue_nodes, V.kernel.removed_buffers
-            )
-            read_names = OrderedSet(evt_read_names) - OrderedSet(evt_write_names)
-            write_names = OrderedSet(evt_write_names)
-            assert write_names, "There should be at least one write"
-            D_output_name = var_name_to_buffer_name["D"]
-            name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs
-            D_output_buffer = name_to_buffer[D_output_name]
-            Y = D_output_buffer  # type: ignore[assignment]
-            # Interestingly, I don't think the rest of the layout matters here since we
-            # use the properties of the Y buffer to fill in D's properties in the epilogue
-            # args. This is needed though because it defines types expected in the epilogue args.
-            op.D.element = cutlass_utils.torch_dtype_to_cutlass_type(
-                D_output_buffer.get_layout().dtype
-            )
+        is_scaled_mm = len(self.input_nodes) == 4
+        if epilogue_nodes or is_scaled_mm:
+            if epilogue_nodes:
+                (
+                    evt_read_names,
+                    evt_write_names,
+                    var_name_to_buffer_name,
+                    evt_py_code,
+                ) = CutlassEVTCodegen.ir_to_evt_python_code(
+                    Y.get_name(), epilogue_nodes, V.kernel.removed_buffers
+                )
+                D_output_name = var_name_to_buffer_name["D"]
+                name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs
+                D_output_buffer = name_to_buffer[D_output_name]
+                D_dtype = D_output_buffer.get_dtype()
+                Y = D_output_buffer  # type: ignore[assignment]
+                # Interestingly, I don't think the rest of the layout matters here since we
+                # use the properties of the Y buffer to fill in D's properties in the epilogue
+                # args. This is needed though because it defines types expected in the epilogue args.
+                op.D.element = cutlass_utils.torch_dtype_to_cutlass_type(
+                    D_output_buffer.get_layout().dtype
+                )
+
+                read_names = OrderedSet(evt_read_names) - OrderedSet(evt_write_names)
+                write_names = OrderedSet(evt_write_names)
+                assert write_names, "There should be at least one write"
+
+                input_names = list(read_names)
+                output_names = list(write_names)
+                epilogue_inputs = [name_to_buffer[name] for name in input_names]
+                epilogue_outputs = [name_to_buffer[name] for name in output_names]
+            else:  # Scaled MM, we read the two scale matrices and write a single output
+                (
+                    evt_read_names,
+                    var_name_to_buffer_name,
+                    evt_py_code,
+                ) = scaled_mm_evt(
+                    self.input_nodes[2].get_name(),
+                    self.input_nodes[3].get_name(),
+                    Y.get_name(),
+                )
+
+                input_names = list(evt_read_names)
+                output_names = []  # We only need Y
+                D_dtype = Y.get_layout().dtype
+                epilogue_inputs = [self.input_nodes[2], self.input_nodes[3]]
+                epilogue_outputs = []
+
             acc_dtype = cutlass_utils.get_accumulator_dtype(
                 [X.get_dtype(), W.get_dtype()]
             )
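Concretely, in the scaled-mm case the template now receives four input nodes; a sketch of the mapping the branch above sets up (comments only, names illustrative):

# input_nodes layout when lowering torch._scaled_mm through this template:
#   input_nodes[0] -> X        (fp8 operand A)
#   input_nodes[1] -> W        (fp8 operand B)
#   input_nodes[2] -> scale_a  (read by the EVT)
#   input_nodes[3] -> scale_b  (read by the EVT)
# Bias is None for the four-input case (see the (2, 4) check further below); the EVT
# multiplies the accumulator by both scales and Y is the only buffer written.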
@@ -1095,24 +1121,20 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
                 op,
                 evt_py_code,
                 var_name_to_buffer_name,
-                D_output_buffer.get_dtype(),
+                D_dtype,
                 acc_dtype,
             )

-            input_names = list(read_names)
-            output_names = list(write_names)
-            epilogue_inputs = [name_to_buffer[name] for name in input_names]
-            epilogue_outputs = [name_to_buffer[name] for name in output_names]
             inputs = [
                 X,
                 W,
                 Bias,
-                Y,
                 *epilogue_inputs,  # type: ignore[list-item]
+                Y,
                 *extra_inputs,
             ]
             names_str = ",".join(
-                ["X", "W", "Bias", "Y", *input_names, *output_names, *extra_names]
+                ["X", "W", "Bias", *input_names, "Y", *output_names, *extra_names]
             )
         else:
             evt_name = None
@@ -1286,7 +1308,7 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
         Returns:
             bool: True if layouts are GEMM compatible, otherwise False.
         """
-        assert len(layouts) == 2 or len(layouts) == 3
+        assert len(layouts) == 2 or len(layouts) == 3 or len(layouts) == 4
         # Check if A and B are compatible
         A_layout, B_layout = layouts[:2]
         if len(A_layout.size) < 1:
@@ -1354,6 +1376,8 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
         from .cutlass_lib_extensions.evt_extensions import create_example_tensors, trace

         name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs
+        # handle the fake output buffer during lowering
+        name_to_buffer[self.output_node.get_name()] = self.output_node  # type: ignore[assignment]

         acc_dtype = torch_dtype_to_cutlass_type(accumulator_dtype)
         output_dtype = torch_dtype_to_cutlass_type(output_dtype)
@@ -1395,7 +1419,7 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
     ) -> bool:
         import cutlass_library.library as cutlass_lib

-        has_bias = len(self.input_nodes) >= 3 and self.input_nodes[2] is not None
+        has_bias = len(self.input_nodes) == 3 and self.input_nodes[2] is not None
         if has_bias:
             Bias = self.input_nodes[2]
             # bias dtype
@@ -1464,7 +1488,7 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
         self,
         op: "cutlass_gemm_op.GemmOperation" = None,  # type: ignore[name-defined]  # noqa: F821
     ) -> tuple[Optional[Buffer], list[Optional[Buffer]], list[str]]:
-        Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2]
+        Bias = None if len(self.input_nodes) in (2, 4) else self.input_nodes[2]
         inputs: list[Optional[Buffer]] = []
         names: list[str] = []
         return (Bias, inputs, names)
@@ -1187,6 +1187,14 @@ def tuned_scaled_mm(
                 epilogue_fn=scale_mm_epilogue(),
             )

+    if is_nonzero and use_cutlass_template(layout, m, n, k):
+        if use_fast_accum:
+            log.warning(
+                "use_fast_accum=True is not supported by cutlass template, skipping cutlass choices"
+            )
+        else:
+            CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, input_nodes)  # type: ignore[arg-type]
+
     if is_nonzero and use_ck_gemm_template(layout, m, n, k):
         CKGemmTemplate.add_ck_gemm_choices(choices, layout, input_nodes)

@@ -228,3 +228,65 @@ def clone_preserve_strides_offset(x, device=None):
         buffer = buffer.to(device, copy=True)
     out = torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset())
     return out
+
+# define the e4m3/e5m2 constants
+E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
+E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
+E4M3FNUZ_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
+E5M2FNUZ_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
+
+FP16_MAX_POS: float = torch.finfo(torch.float16).max
+EPS: float = 1e-12
+
+Tensor = torch.Tensor
+
+
+def _to_fp8_saturated(x: Tensor, float8_dtype: torch.dtype) -> Tensor:
+    # The default behavior in PyTorch for casting to `float8_e4m3fn`
+    # and `e5m2` is to not saturate. In this context, we should saturate.
+    # A common case where we want to saturate is when the history of a
+    # tensor has a maximum value of `amax1`, and the current amax value
+    # is `amax2`, where `amax1 < amax2`. This is common when using delayed
+    # scaling.
+    if float8_dtype == torch.float8_e4m3fn:
+        x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
+    elif float8_dtype == torch.float8_e5m2:
+        x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
+    elif float8_dtype == torch.float8_e4m3fnuz:
+        x = x.clamp(min=-1 * E4M3FNUZ_MAX_POS, max=E4M3FNUZ_MAX_POS)
+    elif float8_dtype == torch.float8_e5m2fnuz:
+        x = x.clamp(min=-1 * E5M2FNUZ_MAX_POS, max=E5M2FNUZ_MAX_POS)
+    else:
+        raise TypeError(f"Unsupported float8_dtype: {float8_dtype}")
+    return x.to(float8_dtype)
+
+
+@torch.no_grad()
+def _amax_to_scale(
+    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
+) -> torch.Tensor:
+    # To make scale dtype to be fp32 for accuracy
+    amax = amax.float()
+    if float8_dtype == torch.float8_e4m3fn:
+        res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
+    else:  # e5m2
+        res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
+
+    # Ensure that the scale is representable in float16,
+    # this helps when amax is small. We are assuming that we don't need
+    # to care about this for float32/bfloat16.
+    if orig_dtype is torch.float16:
+        res = torch.clamp(res, max=FP16_MAX_POS)
+    return res
+
+
+def _quantize_tensorwise(x: Tensor, float8_dtype: torch.dtype):
+    amax = torch.max(torch.abs(x))
+    scale = _amax_to_scale(amax, float8_dtype, x.dtype)
+    x_fp8 = _to_fp8_saturated(x * scale, float8_dtype)
+    inverse_scale = scale.reciprocal()
+    return x_fp8, inverse_scale
+
+
+def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype):
+    amax = torch.max(torch.abs(x), dim=1, keepdim=True).values
+    scale = _amax_to_scale(amax, float8_dtype, x.dtype)
+    x_fp8 = _to_fp8_saturated(x * scale, float8_dtype)
+    inverse_scale = scale.reciprocal()
+    return x_fp8, inverse_scale
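As a quick sanity check of the scale math above: torch.finfo(torch.float8_e4m3fn).max is 448, so a tensor whose absolute maximum is 2.0 gets scale = 448 / 2.0 = 224 before the saturating cast and carries inverse_scale = 1 / 224 for torch._scaled_mm to undo afterwards. Illustrative only:

import torch
from torch.testing._internal.inductor_utils import _quantize_tensorwise

x = torch.tensor([[2.0, -1.0], [0.5, 0.25]], dtype=torch.bfloat16)
x_fp8, inverse_scale = _quantize_tensorwise(x, torch.float8_e4m3fn)

print(inverse_scale.item())                      # ~0.004464 (= 1 / 224)
print(x_fp8.to(torch.float32) * inverse_scale)   # approximately recovers x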