mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Bugfix][Misc] Fix silu_and_mul_nvfp4_quant issue and extract common utils for nvfp4 kernel source files (#23727)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@ -666,7 +666,7 @@ steps:
|
||||
# Quantization
|
||||
- pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
|
||||
- pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
|
||||
# - pytest -v -s tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py
|
||||
- pytest -v -s tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py
|
||||
- pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
|
||||
- pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py
|
||||
- pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
|
||||
@ -676,7 +676,7 @@ steps:
|
||||
- pytest -v -s tests/compile/test_fusion_all_reduce.py
|
||||
- pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
|
||||
- pytest -v -s tests/kernels/moe/test_flashinfer.py
|
||||
# - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
|
||||
- pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
|
||||
|
||||
##### 1 GPU test #####
|
||||
##### multi gpus test #####
|
||||
|
@ -52,15 +52,6 @@
|
||||
#define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
|
||||
|
||||
#define AT_DISPATCH_BYTE_CASE(enum_type, ...) \
|
||||
AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, byte_t, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_CASE_BYTE_TYPES(...) \
|
||||
AT_DISPATCH_BYTE_CASE(at::ScalarType::Byte, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_BYTE_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_BYTE_TYPES(__VA_ARGS__))
|
||||
|
||||
#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
|
||||
|
||||
|
@ -130,8 +130,7 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||
void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
|
||||
torch::Tensor& scale);
|
||||
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
#ifndef USE_ROCM
|
||||
void silu_and_mul_nvfp4_quant(torch::Tensor& out,
|
||||
torch::Tensor& output_block_scale,
|
||||
torch::Tensor& input,
|
||||
|
@ -26,164 +26,17 @@
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include "cuda_utils.h"
|
||||
#include "nvfp4_utils.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Get type2 from type or vice versa (applied to half and bfloat16)
|
||||
template <typename T>
|
||||
struct TypeConverter {
|
||||
using Type = half2;
|
||||
}; // keep for generality
|
||||
|
||||
template <>
|
||||
struct TypeConverter<half2> {
|
||||
using Type = c10::Half;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeConverter<c10::Half> {
|
||||
using Type = half2;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeConverter<__nv_bfloat162> {
|
||||
using Type = c10::BFloat16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeConverter<c10::BFloat16> {
|
||||
using Type = __nv_bfloat162;
|
||||
};
|
||||
|
||||
#define ELTS_PER_THREAD 8
|
||||
|
||||
constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
|
||||
constexpr int CVT_FP4_SF_VEC_SIZE = 16;
|
||||
|
||||
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
|
||||
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
uint32_t val;
|
||||
asm volatile(
|
||||
"{\n"
|
||||
".reg .b8 byte0;\n"
|
||||
".reg .b8 byte1;\n"
|
||||
".reg .b8 byte2;\n"
|
||||
".reg .b8 byte3;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
|
||||
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
|
||||
"}"
|
||||
: "=r"(val)
|
||||
: "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
|
||||
"f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
|
||||
return val;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
|
||||
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
uint32_t val;
|
||||
asm volatile(
|
||||
"{\n"
|
||||
".reg .b8 byte0;\n"
|
||||
".reg .b8 byte1;\n"
|
||||
".reg .b8 byte2;\n"
|
||||
".reg .b8 byte3;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
|
||||
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
|
||||
"}"
|
||||
: "=r"(val)
|
||||
: "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
|
||||
"f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
|
||||
return val;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Fast reciprocal.
|
||||
inline __device__ float reciprocal_approximate_ftz(float a) {
|
||||
float b;
|
||||
asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
|
||||
return b;
|
||||
}
|
||||
|
||||
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
|
||||
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
|
||||
int numCols,
|
||||
SFType* SFout) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
|
||||
CVT_FP4_NUM_THREADS_PER_SF == 2);
|
||||
|
||||
// One pair of threads write one SF to global memory.
|
||||
// TODO: stage through smem for packed STG.32
|
||||
// is it better than STG.8 from 4 threads ?
|
||||
if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
|
||||
// SF vector index (16 elements share one SF in the K dimension).
|
||||
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
|
||||
int32_t mIdx = rowIdx;
|
||||
|
||||
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
|
||||
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
|
||||
|
||||
int32_t mTileIdx = mIdx / (32 * 4);
|
||||
// SF vector size 16.
|
||||
int factor = CVT_FP4_SF_VEC_SIZE * 4;
|
||||
int32_t numKTiles = (numCols + factor - 1) / factor;
|
||||
int64_t mTileStride = numKTiles * 32 * 4 * 4;
|
||||
|
||||
int32_t kTileIdx = (kIdx / 4);
|
||||
int64_t kTileStride = 32 * 4 * 4;
|
||||
|
||||
// M tile layout [32, 4] is column-major.
|
||||
int32_t outerMIdx = (mIdx % 32);
|
||||
int64_t outerMStride = 4 * 4;
|
||||
|
||||
int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
|
||||
int64_t innerMStride = 4;
|
||||
|
||||
int32_t innerKIdx = (kIdx % 4);
|
||||
int64_t innerKStride = 1;
|
||||
|
||||
// Compute the global offset.
|
||||
int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride +
|
||||
outerMIdx * outerMStride + innerMIdx * innerMStride +
|
||||
innerKIdx * innerKStride;
|
||||
|
||||
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
|
||||
}
|
||||
#endif
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Define a 16 bytes packed data type.
|
||||
template <class Type>
|
||||
struct PackedVec {
|
||||
typename TypeConverter<Type>::Type elts[4];
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PackedVec<__nv_fp8_e4m3> {
|
||||
__nv_fp8x2_e4m3 elts[8];
|
||||
};
|
||||
|
||||
template <class Type>
|
||||
__inline__ __device__ PackedVec<Type> compute_silu(PackedVec<Type>& vec,
|
||||
PackedVec<Type>& vec2) {
|
||||
PackedVec<Type> result;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) {
|
||||
if constexpr (std::is_same_v<Type, c10::Half>) {
|
||||
if constexpr (std::is_same_v<Type, half>) {
|
||||
half2 val(0.5f, 0.5f);
|
||||
half2 t0 = __hmul2(vec.elts[i], val);
|
||||
half2 t1 = __hfma2(h2tanh(t0), val, val);
|
||||
@ -206,13 +59,12 @@ __device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec<Type>& vec,
|
||||
PackedVec<Type>& vec2,
|
||||
float SFScaleVal,
|
||||
uint8_t* SFout) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
PackedVec<Type> out_silu = compute_silu(vec, vec2);
|
||||
// Get absolute maximum values among the local 8 values.
|
||||
auto localMax = __habs2(out_silu.elts[0]);
|
||||
|
||||
// Local maximum value.
|
||||
#pragma unroll
|
||||
// Local maximum value.
|
||||
#pragma unroll
|
||||
for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
|
||||
localMax = __hmax2(localMax, __habs2(out_silu.elts[i]));
|
||||
}
|
||||
@ -259,9 +111,9 @@ __device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec<Type>& vec,
|
||||
// Convert the input to float.
|
||||
float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
|
||||
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
|
||||
if constexpr (std::is_same_v<Type, c10::Half>) {
|
||||
if constexpr (std::is_same_v<Type, half>) {
|
||||
fp2Vals[i] = __half22float2(out_silu.elts[i]);
|
||||
} else {
|
||||
fp2Vals[i] = __bfloat1622float2(out_silu.elts[i]);
|
||||
@ -275,22 +127,14 @@ __device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec<Type>& vec,
|
||||
|
||||
// Write the e2m1 values to global memory.
|
||||
return e2m1Vec;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Use UE4M3 by default.
|
||||
template <class Type, bool UE8M0_SF = false>
|
||||
__global__ void
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
__launch_bounds__(1024, 4) silu_and_cvt_fp16_to_fp4(
|
||||
#else
|
||||
silu_and_cvt_fp16_to_fp4(
|
||||
#endif
|
||||
int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
|
||||
uint32_t* out, uint32_t* SFout) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
__global__ void __launch_bounds__(1024, 4)
|
||||
silu_and_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
|
||||
float const* SFScale, uint32_t* out,
|
||||
uint32_t* SFout) {
|
||||
using PackedVec = PackedVec<Type>;
|
||||
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
|
||||
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
||||
@ -328,22 +172,25 @@ silu_and_cvt_fp16_to_fp4(
|
||||
in_vec, in_vec2, SFScaleVal, sf_out);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void silu_and_mul_nvfp4_quant(torch::Tensor& output, // [..., d]
|
||||
torch::Tensor& output_sf,
|
||||
torch::Tensor& input, // [..., 2 * d]
|
||||
torch::Tensor& input_sf) {
|
||||
TORCH_CHECK(input.dtype() == torch::kFloat16 ||
|
||||
input.dtype() == torch::kBFloat16);
|
||||
void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
|
||||
torch::Tensor& output_sf,
|
||||
torch::Tensor& input, // [..., 2 * d]
|
||||
torch::Tensor& input_sf) {
|
||||
int32_t m = input.size(0);
|
||||
int32_t n = input.size(1) / 2;
|
||||
|
||||
TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
|
||||
TORCH_CHECK(input.scalar_type() == at::ScalarType::Half ||
|
||||
input.scalar_type() == at::ScalarType::BFloat16,
|
||||
"Unsupported input data type for quantize_to_fp4.");
|
||||
|
||||
int multiProcessorCount =
|
||||
get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
|
||||
|
||||
auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
|
||||
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
|
||||
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
|
||||
@ -352,17 +199,14 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& output, // [..., d]
|
||||
dim3 block(std::min(int(n / ELTS_PER_THREAD), 1024));
|
||||
int const numBlocksPerSM = 2048 / block.x;
|
||||
dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
|
||||
|
||||
VLLM_DISPATCH_HALF_TYPES(
|
||||
input.scalar_type(), "act_and_mul_quant_kernel", [&] {
|
||||
auto input_ptr = reinterpret_cast<scalar_t const*>(input.data_ptr());
|
||||
VLLM_DISPATCH_BYTE_TYPES(
|
||||
output.scalar_type(), "fused_act_and_mul_quant_kernel_nvfp4_type",
|
||||
[&] {
|
||||
vllm::silu_and_cvt_fp16_to_fp4<scalar_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
m, n, input_ptr, input_sf_ptr,
|
||||
reinterpret_cast<uint32_t*>(output_ptr),
|
||||
reinterpret_cast<uint32_t*>(sf_out));
|
||||
});
|
||||
input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] {
|
||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
|
||||
vllm::silu_and_cvt_fp16_to_fp4<cuda_type><<<grid, block, 0, stream>>>(
|
||||
m, n, input_ptr, input_sf_ptr,
|
||||
reinterpret_cast<uint32_t*>(output_ptr),
|
||||
reinterpret_cast<uint32_t*>(sf_out));
|
||||
});
|
||||
}
|
||||
|
@ -1,3 +1,19 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <cutlass/arch/arch.h>
|
||||
|
||||
|
@ -1,247 +1,42 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
template <typename T>
|
||||
struct TypeConverter {
|
||||
using Type = half2;
|
||||
}; // keep for generality
|
||||
#include "nvfp4_utils.cuh"
|
||||
|
||||
template <>
|
||||
struct TypeConverter<half2> {
|
||||
using Type = half;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeConverter<half> {
|
||||
using Type = half2;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeConverter<__nv_bfloat162> {
|
||||
using Type = __nv_bfloat16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeConverter<__nv_bfloat16> {
|
||||
using Type = __nv_bfloat162;
|
||||
};
|
||||
|
||||
#define ELTS_PER_THREAD 8
|
||||
|
||||
constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
|
||||
constexpr int CVT_FP4_SF_VEC_SIZE = 16;
|
||||
|
||||
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
|
||||
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
uint32_t val;
|
||||
asm volatile(
|
||||
"{\n"
|
||||
".reg .b8 byte0;\n"
|
||||
".reg .b8 byte1;\n"
|
||||
".reg .b8 byte2;\n"
|
||||
".reg .b8 byte3;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
|
||||
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
|
||||
"}"
|
||||
: "=r"(val)
|
||||
: "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
|
||||
"f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
|
||||
return val;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
|
||||
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
uint32_t val;
|
||||
asm volatile(
|
||||
"{\n"
|
||||
".reg .b8 byte0;\n"
|
||||
".reg .b8 byte1;\n"
|
||||
".reg .b8 byte2;\n"
|
||||
".reg .b8 byte3;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
|
||||
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
|
||||
"}"
|
||||
: "=r"(val)
|
||||
: "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
|
||||
"f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
|
||||
return val;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Fast reciprocal.
|
||||
inline __device__ float reciprocal_approximate_ftz(float a) {
|
||||
float b;
|
||||
asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
|
||||
return b;
|
||||
}
|
||||
|
||||
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
|
||||
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
|
||||
int numCols,
|
||||
SFType* SFout) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
|
||||
CVT_FP4_NUM_THREADS_PER_SF == 2);
|
||||
|
||||
// One pair of threads write one SF to global memory.
|
||||
// TODO: stage through smem for packed STG.32
|
||||
// is it better than STG.8 from 4 threads ?
|
||||
if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
|
||||
// SF vector index (16 elements share one SF in the K dimension).
|
||||
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
|
||||
int32_t mIdx = rowIdx;
|
||||
|
||||
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
|
||||
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
|
||||
|
||||
int32_t mTileIdx = mIdx / (32 * 4);
|
||||
// SF vector size 16.
|
||||
int factor = CVT_FP4_SF_VEC_SIZE * 4;
|
||||
int32_t numKTiles = (numCols + factor - 1) / factor;
|
||||
int64_t mTileStride = numKTiles * 32 * 4 * 4;
|
||||
|
||||
int32_t kTileIdx = (kIdx / 4);
|
||||
int64_t kTileStride = 32 * 4 * 4;
|
||||
|
||||
// M tile layout [32, 4] is column-major.
|
||||
int32_t outerMIdx = (mIdx % 32);
|
||||
int64_t outerMStride = 4 * 4;
|
||||
|
||||
int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
|
||||
int64_t innerMStride = 4;
|
||||
|
||||
int32_t innerKIdx = (kIdx % 4);
|
||||
int64_t innerKStride = 1;
|
||||
|
||||
// Compute the global offset.
|
||||
int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride +
|
||||
outerMIdx * outerMStride + innerMIdx * innerMStride +
|
||||
innerKIdx * innerKStride;
|
||||
|
||||
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
|
||||
}
|
||||
#endif
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Define a 16 bytes packed data type.
|
||||
template <class Type>
|
||||
struct PackedVec {
|
||||
typename TypeConverter<Type>::Type elts[4];
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PackedVec<__nv_fp8_e4m3> {
|
||||
__nv_fp8x2_e4m3 elts[8];
|
||||
};
|
||||
|
||||
// Quantizes the provided PackedVec into the uint32_t output
|
||||
template <class Type, bool UE8M0_SF = false>
|
||||
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
|
||||
uint8_t* SFout) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
// Get absolute maximum values among the local 8 values.
|
||||
auto localMax = __habs2(vec.elts[0]);
|
||||
|
||||
// Local maximum value.
|
||||
#pragma unroll
|
||||
for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
|
||||
localMax = __hmax2(localMax, __habs2(vec.elts[i]));
|
||||
}
|
||||
|
||||
// Get the absolute maximum among all 16 values (two threads).
|
||||
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
|
||||
// Get the final absolute maximum values.
|
||||
float vecMax = float(__hmax(localMax.x, localMax.y));
|
||||
|
||||
// Get the SF (max value of the vector / max value of e2m1).
|
||||
// maximum value of e2m1 = 6.0.
|
||||
// TODO: use half as compute data type.
|
||||
float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
|
||||
// 8 bits representation of the SF.
|
||||
uint8_t fp8SFVal;
|
||||
// Write the SF to global memory (STG.8).
|
||||
if constexpr (UE8M0_SF) {
|
||||
// Extract the 8 exponent bits from float32.
|
||||
// float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
|
||||
uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
|
||||
fp8SFVal = tmp & 0xff;
|
||||
// Convert back to fp32.
|
||||
reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
|
||||
} else {
|
||||
// Here SFValue is always positive, so E4M3 is the same as UE4M3.
|
||||
__nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
|
||||
reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
|
||||
// Convert back to fp32.
|
||||
SFValue = float(tmp);
|
||||
}
|
||||
// Get the output scale.
|
||||
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
|
||||
// reciprocal(SFScaleVal))
|
||||
float outputScale =
|
||||
SFValue != 0 ? reciprocal_approximate_ftz(
|
||||
SFValue * reciprocal_approximate_ftz(SFScaleVal))
|
||||
: 0.0f;
|
||||
|
||||
if (SFout) {
|
||||
// Write the SF to global memory (STG.8).
|
||||
*SFout = fp8SFVal;
|
||||
}
|
||||
|
||||
// Convert the input to float.
|
||||
float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
|
||||
if constexpr (std::is_same_v<Type, half>) {
|
||||
fp2Vals[i] = __half22float2(vec.elts[i]);
|
||||
} else {
|
||||
fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
|
||||
}
|
||||
fp2Vals[i].x *= outputScale;
|
||||
fp2Vals[i].y *= outputScale;
|
||||
}
|
||||
|
||||
// Convert to e2m1 values.
|
||||
uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
|
||||
|
||||
// Write the e2m1 values to global memory.
|
||||
return e2m1Vec;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
namespace vllm {
|
||||
|
||||
// Use UE4M3 by default.
|
||||
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
|
||||
__global__ void
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
__launch_bounds__(512, 4) cvt_fp16_to_fp4(
|
||||
#else
|
||||
cvt_fp16_to_fp4(
|
||||
#endif
|
||||
int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
|
||||
uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts,
|
||||
uint32_t* output_scale_offset_by_experts, int n_experts, bool low_latency) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
__global__ void __launch_bounds__(512, 4)
|
||||
cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
|
||||
float const* SFScale, uint32_t* out, uint32_t* SFout,
|
||||
uint32_t* input_offset_by_experts,
|
||||
uint32_t* output_scale_offset_by_experts, int n_experts,
|
||||
bool low_latency) {
|
||||
using PackedVec = PackedVec<Type>;
|
||||
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
|
||||
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
||||
@ -299,8 +94,8 @@ cvt_fp16_to_fp4(
|
||||
&input_offset_by_experts[chunk_start + 12]));
|
||||
local_offsets[16] = __ldca(&input_offset_by_experts[chunk_start + 16]);
|
||||
|
||||
// Check against the 16 loaded offsets
|
||||
#pragma unroll
|
||||
// Check against the 16 loaded offsets
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 16; i++) {
|
||||
if (rowIdx >= local_offsets[i] && rowIdx < local_offsets[i + 1]) {
|
||||
rowIdx_in_expert = rowIdx - local_offsets[i];
|
||||
@ -330,21 +125,15 @@ cvt_fp16_to_fp4(
|
||||
|
||||
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// Kernel for LARGE_M_TOPK = true (large m_topk optimized version)
|
||||
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
|
||||
__global__ void
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
__launch_bounds__(1024, 4) cvt_fp16_to_fp4(
|
||||
#else
|
||||
cvt_fp16_to_fp4(
|
||||
#endif
|
||||
int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
|
||||
uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts,
|
||||
uint32_t* output_scale_offset_by_experts, int n_experts) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
__global__ void __launch_bounds__(1024, 4)
|
||||
cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
|
||||
float const* SFScale, uint32_t* out, uint32_t* SFout,
|
||||
uint32_t* input_offset_by_experts,
|
||||
uint32_t* output_scale_offset_by_experts, int n_experts) {
|
||||
using PackedVec = PackedVec<Type>;
|
||||
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
|
||||
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
||||
@ -425,7 +214,6 @@ cvt_fp16_to_fp4(
|
||||
|
||||
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@ -501,6 +289,8 @@ void quant_impl(void* output, void* output_scale, void* input,
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
/*Quantization entry for fp4 experts quantization*/
|
||||
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x, m) \
|
||||
@ -560,23 +350,17 @@ void scaled_fp4_experts_quant_sm100a(
|
||||
// 4 means 4 fp8 values are packed into one int32
|
||||
TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
|
||||
|
||||
auto in_dtype = input.dtype();
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream =
|
||||
at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
if (in_dtype == at::ScalarType::Half) {
|
||||
quant_impl<half>(output.data_ptr(), output_scale.data_ptr(),
|
||||
input.data_ptr(), input_global_scale.data_ptr(),
|
||||
input_offset_by_experts.data_ptr(),
|
||||
output_scale_offset_by_experts.data_ptr(), m_topk, k,
|
||||
n_experts, stream);
|
||||
} else if (in_dtype == at::ScalarType::BFloat16) {
|
||||
quant_impl<__nv_bfloat16>(output.data_ptr(), output_scale.data_ptr(),
|
||||
input.data_ptr(), input_global_scale.data_ptr(),
|
||||
input_offset_by_experts.data_ptr(),
|
||||
output_scale_offset_by_experts.data_ptr(), m_topk,
|
||||
k, n_experts, stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Expected input data type to be half or bfloat16");
|
||||
}
|
||||
|
||||
VLLM_DISPATCH_HALF_TYPES(
|
||||
input.scalar_type(), "nvfp4_experts_quant_kernel", [&] {
|
||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||
vllm::quant_impl<cuda_type>(
|
||||
output.data_ptr(), output_scale.data_ptr(), input.data_ptr(),
|
||||
input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(),
|
||||
output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts,
|
||||
stream);
|
||||
});
|
||||
}
|
||||
|
@ -32,6 +32,14 @@ void scaled_fp4_experts_quant_sm100a(
|
||||
torch::Tensor const& output_scale_offset_by_experts);
|
||||
#endif
|
||||
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output,
|
||||
torch::Tensor& output_sf,
|
||||
torch::Tensor& input,
|
||||
torch::Tensor& input_sf);
|
||||
#endif
|
||||
|
||||
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
|
||||
torch::Tensor& output_sf, torch::Tensor const& input_sf) {
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
@ -54,3 +62,13 @@ void scaled_fp4_experts_quant(
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
"No compiled nvfp4 experts quantization kernel");
|
||||
}
|
||||
|
||||
void silu_and_mul_nvfp4_quant(torch::Tensor& output, torch::Tensor& output_sf,
|
||||
torch::Tensor& input, torch::Tensor& input_sf) {
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf);
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "No compiled silu_and_mul nvfp4 quantization kernel");
|
||||
}
|
||||
|
@ -23,245 +23,18 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <cuda_fp8.h>
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include "cuda_utils.h"
|
||||
#include "nvfp4_utils.cuh"
|
||||
|
||||
// Get type2 from type or vice versa (applied to half and bfloat16)
|
||||
template <typename T>
|
||||
struct TypeConverter {
|
||||
using Type = half2;
|
||||
}; // keep for generality
|
||||
|
||||
template <>
|
||||
struct TypeConverter<half2> {
|
||||
using Type = half;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeConverter<half> {
|
||||
using Type = half2;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeConverter<__nv_bfloat162> {
|
||||
using Type = __nv_bfloat16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeConverter<__nv_bfloat16> {
|
||||
using Type = __nv_bfloat162;
|
||||
};
|
||||
|
||||
#define ELTS_PER_THREAD 8
|
||||
|
||||
constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
|
||||
constexpr int CVT_FP4_SF_VEC_SIZE = 16;
|
||||
|
||||
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
|
||||
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
uint32_t val;
|
||||
asm volatile(
|
||||
"{\n"
|
||||
".reg .b8 byte0;\n"
|
||||
".reg .b8 byte1;\n"
|
||||
".reg .b8 byte2;\n"
|
||||
".reg .b8 byte3;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
|
||||
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
|
||||
"}"
|
||||
: "=r"(val)
|
||||
: "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
|
||||
"f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
|
||||
return val;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
|
||||
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
uint32_t val;
|
||||
asm volatile(
|
||||
"{\n"
|
||||
".reg .b8 byte0;\n"
|
||||
".reg .b8 byte1;\n"
|
||||
".reg .b8 byte2;\n"
|
||||
".reg .b8 byte3;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
|
||||
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
|
||||
"}"
|
||||
: "=r"(val)
|
||||
: "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
|
||||
"f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
|
||||
return val;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Fast reciprocal.
|
||||
inline __device__ float reciprocal_approximate_ftz(float a) {
|
||||
float b;
|
||||
asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
|
||||
return b;
|
||||
}
|
||||
|
||||
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
|
||||
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
|
||||
int numCols,
|
||||
SFType* SFout) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
|
||||
CVT_FP4_NUM_THREADS_PER_SF == 2);
|
||||
|
||||
// One pair of threads write one SF to global memory.
|
||||
// TODO: stage through smem for packed STG.32
|
||||
// is it better than STG.8 from 4 threads ?
|
||||
if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
|
||||
// SF vector index (16 elements share one SF in the K dimension).
|
||||
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
|
||||
int32_t mIdx = rowIdx;
|
||||
|
||||
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
|
||||
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
|
||||
|
||||
int32_t mTileIdx = mIdx / (32 * 4);
|
||||
// SF vector size 16.
|
||||
int factor = CVT_FP4_SF_VEC_SIZE * 4;
|
||||
int32_t numKTiles = (numCols + factor - 1) / factor;
|
||||
int64_t mTileStride = numKTiles * 32 * 4 * 4;
|
||||
|
||||
int32_t kTileIdx = (kIdx / 4);
|
||||
int64_t kTileStride = 32 * 4 * 4;
|
||||
|
||||
// M tile layout [32, 4] is column-major.
|
||||
int32_t outerMIdx = (mIdx % 32);
|
||||
int64_t outerMStride = 4 * 4;
|
||||
|
||||
int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
|
||||
int64_t innerMStride = 4;
|
||||
|
||||
int32_t innerKIdx = (kIdx % 4);
|
||||
int64_t innerKStride = 1;
|
||||
|
||||
// Compute the global offset.
|
||||
int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride +
|
||||
outerMIdx * outerMStride + innerMIdx * innerMStride +
|
||||
innerKIdx * innerKStride;
|
||||
|
||||
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
|
||||
}
|
||||
#endif
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Define a 16 bytes packed data type.
|
||||
template <class Type>
|
||||
struct PackedVec {
|
||||
typename TypeConverter<Type>::Type elts[4];
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PackedVec<__nv_fp8_e4m3> {
|
||||
__nv_fp8x2_e4m3 elts[8];
|
||||
};
|
||||
|
||||
// Quantizes the provided PackedVec into the uint32_t output
|
||||
template <class Type, bool UE8M0_SF = false>
|
||||
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
|
||||
uint8_t* SFout) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
// Get absolute maximum values among the local 8 values.
|
||||
auto localMax = __habs2(vec.elts[0]);
|
||||
|
||||
// Local maximum value.
|
||||
#pragma unroll
|
||||
for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
|
||||
localMax = __hmax2(localMax, __habs2(vec.elts[i]));
|
||||
}
|
||||
|
||||
// Get the absolute maximum among all 16 values (two threads).
|
||||
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
|
||||
// Get the final absolute maximum values.
|
||||
float vecMax = float(__hmax(localMax.x, localMax.y));
|
||||
|
||||
// Get the SF (max value of the vector / max value of e2m1).
|
||||
// maximum value of e2m1 = 6.0.
|
||||
// TODO: use half as compute data type.
|
||||
float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
|
||||
// 8 bits representation of the SF.
|
||||
uint8_t fp8SFVal;
|
||||
// Write the SF to global memory (STG.8).
|
||||
if constexpr (UE8M0_SF) {
|
||||
// Extract the 8 exponent bits from float32.
|
||||
// float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
|
||||
uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
|
||||
fp8SFVal = tmp & 0xff;
|
||||
// Convert back to fp32.
|
||||
reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
|
||||
} else {
|
||||
// Here SFValue is always positive, so E4M3 is the same as UE4M3.
|
||||
__nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
|
||||
reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
|
||||
// Convert back to fp32.
|
||||
SFValue = float(tmp);
|
||||
}
|
||||
// Get the output scale.
|
||||
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
|
||||
// reciprocal(SFScaleVal))
|
||||
float outputScale =
|
||||
SFValue != 0 ? reciprocal_approximate_ftz(
|
||||
SFValue * reciprocal_approximate_ftz(SFScaleVal))
|
||||
: 0.0f;
|
||||
|
||||
if (SFout) {
|
||||
// Write the SF to global memory (STG.8).
|
||||
*SFout = fp8SFVal;
|
||||
}
|
||||
|
||||
// Convert the input to float.
|
||||
float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
|
||||
if constexpr (std::is_same_v<Type, half>) {
|
||||
fp2Vals[i] = __half22float2(vec.elts[i]);
|
||||
} else {
|
||||
fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
|
||||
}
|
||||
fp2Vals[i].x *= outputScale;
|
||||
fp2Vals[i].y *= outputScale;
|
||||
}
|
||||
|
||||
// Convert to e2m1 values.
|
||||
uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
|
||||
|
||||
// Write the e2m1 values to global memory.
|
||||
return e2m1Vec;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
namespace vllm {
|
||||
|
||||
// Use UE4M3 by default.
|
||||
template <class Type, bool UE8M0_SF = false>
|
||||
__global__ void
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
__launch_bounds__(512, 4) cvt_fp16_to_fp4(
|
||||
#else
|
||||
cvt_fp16_to_fp4(
|
||||
#endif
|
||||
int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
|
||||
uint32_t* out, uint32_t* SFout) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
__global__ void __launch_bounds__(512, 4)
|
||||
cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
|
||||
float const* SFScale, uint32_t* out, uint32_t* SFout) {
|
||||
using PackedVec = PackedVec<Type>;
|
||||
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
|
||||
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
||||
@ -293,7 +66,6 @@ cvt_fp16_to_fp4(
|
||||
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@ -332,6 +104,8 @@ template void invokeFP4Quantization(int m, int n, __nv_bfloat16 const* input,
|
||||
int multiProcessorCount,
|
||||
cudaStream_t stream);
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor const& output_sf,
|
||||
@ -340,6 +114,9 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
|
||||
int32_t n = input.size(1);
|
||||
|
||||
TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
|
||||
TORCH_CHECK(input.scalar_type() == at::ScalarType::Half ||
|
||||
input.scalar_type() == at::ScalarType::BFloat16,
|
||||
"Unsupported input data type for quantize_to_fp4.");
|
||||
|
||||
int multiProcessorCount =
|
||||
get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
|
||||
@ -353,24 +130,10 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
|
||||
// We don't support e8m0 scales at this moment.
|
||||
bool useUE8M0 = false;
|
||||
|
||||
switch (input.scalar_type()) {
|
||||
case torch::kHalf: {
|
||||
auto input_ptr = reinterpret_cast<half const*>(input.data_ptr());
|
||||
invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out,
|
||||
useUE8M0, multiProcessorCount, stream);
|
||||
break;
|
||||
}
|
||||
case torch::kBFloat16: {
|
||||
auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr());
|
||||
invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out,
|
||||
useUE8M0, multiProcessorCount, stream);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
std::cerr << "Observing: " << input.scalar_type()
|
||||
<< " for the input datatype which is invalid";
|
||||
throw std::runtime_error(
|
||||
"Unsupported input data type for quantize_to_fp4.");
|
||||
}
|
||||
}
|
||||
VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
|
||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
|
||||
vllm::invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr,
|
||||
sf_out, useUE8M0, multiProcessorCount, stream);
|
||||
});
|
||||
}
|
||||
|
251
csrc/quantization/fp4/nvfp4_utils.cuh
Normal file
251
csrc/quantization/fp4/nvfp4_utils.cuh
Normal file
@ -0,0 +1,251 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp8.h>
|
||||
|
||||
#define ELTS_PER_THREAD 8
|
||||
|
||||
constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
|
||||
constexpr int CVT_FP4_SF_VEC_SIZE = 16;
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Convert PyTorch cpp type to CUDA type
|
||||
template <typename T>
|
||||
struct CUDATypeConverter {
|
||||
using Type = T;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CUDATypeConverter<at::Half> {
|
||||
using Type = half;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CUDATypeConverter<at::BFloat16> {
|
||||
using Type = __nv_bfloat16;
|
||||
};
|
||||
|
||||
// Get type2 from type or vice versa (applied to half and bfloat16)
|
||||
template <typename T>
|
||||
struct TypeConverter {
|
||||
using Type = half2;
|
||||
}; // keep for generality
|
||||
|
||||
template <>
|
||||
struct TypeConverter<half2> {
|
||||
using Type = half;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeConverter<half> {
|
||||
using Type = half2;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeConverter<__nv_bfloat162> {
|
||||
using Type = __nv_bfloat16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeConverter<__nv_bfloat16> {
|
||||
using Type = __nv_bfloat162;
|
||||
};
|
||||
|
||||
// Define a 16 bytes packed data type.
|
||||
template <class Type>
|
||||
struct PackedVec {
|
||||
typename TypeConverter<Type>::Type elts[4];
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PackedVec<__nv_fp8_e4m3> {
|
||||
__nv_fp8x2_e4m3 elts[8];
|
||||
};
|
||||
|
||||
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
|
||||
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
|
||||
uint32_t val;
|
||||
asm volatile(
|
||||
"{\n"
|
||||
".reg .b8 byte0;\n"
|
||||
".reg .b8 byte1;\n"
|
||||
".reg .b8 byte2;\n"
|
||||
".reg .b8 byte3;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
|
||||
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
|
||||
"}"
|
||||
: "=r"(val)
|
||||
: "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
|
||||
"f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
|
||||
return val;
|
||||
}
|
||||
|
||||
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
|
||||
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
|
||||
uint32_t val;
|
||||
asm volatile(
|
||||
"{\n"
|
||||
".reg .b8 byte0;\n"
|
||||
".reg .b8 byte1;\n"
|
||||
".reg .b8 byte2;\n"
|
||||
".reg .b8 byte3;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
|
||||
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
|
||||
"}"
|
||||
: "=r"(val)
|
||||
: "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
|
||||
"f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
|
||||
return val;
|
||||
}
|
||||
|
||||
// Fast reciprocal.
|
||||
inline __device__ float reciprocal_approximate_ftz(float a) {
|
||||
float b;
|
||||
asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
|
||||
return b;
|
||||
}
|
||||
|
||||
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
|
||||
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
|
||||
int numCols,
|
||||
SFType* SFout) {
|
||||
static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
|
||||
CVT_FP4_NUM_THREADS_PER_SF == 2);
|
||||
|
||||
// One pair of threads write one SF to global memory.
|
||||
// TODO: stage through smem for packed STG.32
|
||||
// is it better than STG.8 from 4 threads ?
|
||||
if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
|
||||
// SF vector index (16 elements share one SF in the K dimension).
|
||||
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
|
||||
int32_t mIdx = rowIdx;
|
||||
|
||||
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
|
||||
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
|
||||
|
||||
int32_t mTileIdx = mIdx / (32 * 4);
|
||||
// SF vector size 16.
|
||||
int factor = CVT_FP4_SF_VEC_SIZE * 4;
|
||||
int32_t numKTiles = (numCols + factor - 1) / factor;
|
||||
int64_t mTileStride = numKTiles * 32 * 4 * 4;
|
||||
|
||||
int32_t kTileIdx = (kIdx / 4);
|
||||
int64_t kTileStride = 32 * 4 * 4;
|
||||
|
||||
// M tile layout [32, 4] is column-major.
|
||||
int32_t outerMIdx = (mIdx % 32);
|
||||
int64_t outerMStride = 4 * 4;
|
||||
|
||||
int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
|
||||
int64_t innerMStride = 4;
|
||||
|
||||
int32_t innerKIdx = (kIdx % 4);
|
||||
int64_t innerKStride = 1;
|
||||
|
||||
// Compute the global offset.
|
||||
int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride +
|
||||
outerMIdx * outerMStride + innerMIdx * innerMStride +
|
||||
innerKIdx * innerKStride;
|
||||
|
||||
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Quantizes the provided PackedVec into the uint32_t output
|
||||
template <class Type, bool UE8M0_SF = false>
|
||||
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
|
||||
uint8_t* SFout) {
|
||||
// Get absolute maximum values among the local 8 values.
|
||||
auto localMax = __habs2(vec.elts[0]);
|
||||
|
||||
// Local maximum value.
|
||||
#pragma unroll
|
||||
for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
|
||||
localMax = __hmax2(localMax, __habs2(vec.elts[i]));
|
||||
}
|
||||
|
||||
// Get the absolute maximum among all 16 values (two threads).
|
||||
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
|
||||
// Get the final absolute maximum values.
|
||||
float vecMax = float(__hmax(localMax.x, localMax.y));
|
||||
|
||||
// Get the SF (max value of the vector / max value of e2m1).
|
||||
// maximum value of e2m1 = 6.0.
|
||||
// TODO: use half as compute data type.
|
||||
float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
|
||||
// 8 bits representation of the SF.
|
||||
uint8_t fp8SFVal;
|
||||
// Write the SF to global memory (STG.8).
|
||||
if constexpr (UE8M0_SF) {
|
||||
// Extract the 8 exponent bits from float32.
|
||||
// float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
|
||||
uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
|
||||
fp8SFVal = tmp & 0xff;
|
||||
// Convert back to fp32.
|
||||
reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
|
||||
} else {
|
||||
// Here SFValue is always positive, so E4M3 is the same as UE4M3.
|
||||
__nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
|
||||
reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
|
||||
// Convert back to fp32.
|
||||
SFValue = float(tmp);
|
||||
}
|
||||
// Get the output scale.
|
||||
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
|
||||
// reciprocal(SFScaleVal))
|
||||
float outputScale =
|
||||
SFValue != 0 ? reciprocal_approximate_ftz(
|
||||
SFValue * reciprocal_approximate_ftz(SFScaleVal))
|
||||
: 0.0f;
|
||||
|
||||
if (SFout) {
|
||||
// Write the SF to global memory (STG.8).
|
||||
*SFout = fp8SFVal;
|
||||
}
|
||||
|
||||
// Convert the input to float.
|
||||
float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
|
||||
if constexpr (std::is_same_v<Type, half>) {
|
||||
fp2Vals[i] = __half22float2(vec.elts[i]);
|
||||
} else {
|
||||
fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
|
||||
}
|
||||
fp2Vals[i].x *= outputScale;
|
||||
fp2Vals[i].y *= outputScale;
|
||||
}
|
||||
|
||||
// Convert to e2m1 values.
|
||||
uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
|
||||
|
||||
// Write the e2m1 values to global memory.
|
||||
return e2m1Vec;
|
||||
}
|
||||
|
||||
} // namespace vllm
|
@ -115,8 +115,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
|
||||
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
|
||||
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
#ifndef USE_ROCM
|
||||
ops.def(
|
||||
"silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
|
||||
"Tensor input, Tensor input_global_scale) -> ()");
|
||||
|
@ -8,8 +8,7 @@ from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
if not (current_platform.has_device_capability(100)
|
||||
and hasattr(torch.ops._C, "silu_and_mul_nvfp4_quant")):
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
|
||||
allow_module_level=True)
|
||||
|
||||
|
Reference in New Issue
Block a user