[NVIDIA] Support SiluMul + NVFP4 quant fusion (#23671)
Signed-off-by: jindih <jindih@nvidia.com>
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: jindih <jindih@nvidia.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Luka Govedic <lgovedic@redhat.com>
@@ -668,6 +668,7 @@ steps:
  # Quantization
  - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
  - pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
  - pytest -v -s tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py
  - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
  - pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py
  - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
@@ -677,6 +678,7 @@ steps:
  - pytest -v -s tests/compile/test_fusion_all_reduce.py
  - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
  - pytest -v -s tests/kernels/moe/test_flashinfer.py
  - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py

##### 1 GPU test #####
##### multi gpus test #####
@@ -541,6 +541,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
    set(SRCS
      "csrc/quantization/fp4/nvfp4_quant_kernels.cu"
      "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
      "csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu")
    set_gencode_flags_for_srcs(
      SRCS "${SRCS}"
@@ -559,6 +560,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
    set(SRCS
      "csrc/quantization/fp4/nvfp4_quant_kernels.cu"
      "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
      "csrc/quantization/fp4/nvfp4_experts_quant.cu"
      "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
      "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
@@ -19,6 +19,13 @@
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_HALF_TYPES(...)              \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_HALF_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_HALF_TYPES(__VA_ARGS__))

// ROCm devices might use either fn or fnuz, so set up dispatch table for both.
// A host-based check at runtime will create a preferred FP8 type for ROCm
// such that the correct kernel is dispatched.
@@ -45,6 +52,15 @@
#define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))

#define AT_DISPATCH_BYTE_CASE(enum_type, ...) \
  AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, byte_t, __VA_ARGS__)

#define VLLM_DISPATCH_CASE_BYTE_TYPES(...) \
  AT_DISPATCH_BYTE_CASE(at::ScalarType::Byte, __VA_ARGS__)

#define VLLM_DISPATCH_BYTE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_BYTE_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
@@ -130,6 +130,14 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
                        torch::Tensor& scale);

#ifndef USE_ROCM
void silu_and_mul_nvfp4_quant(torch::Tensor& out,
                              torch::Tensor& output_block_scale,
                              torch::Tensor& input,
                              torch::Tensor& input_global_scale);
#endif

void mul_and_silu(torch::Tensor& out, torch::Tensor& input);

void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu (new file, 368 lines)
@@ -0,0 +1,368 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/all.h>

#include <cuda_runtime_api.h>
#include <cuda_runtime.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda_fp8.h>
#include "dispatch_utils.h"

#include "cuda_utils.h"

namespace vllm {

// Get type2 from type or vice versa (applied to half and bfloat16)
template <typename T>
struct TypeConverter {
  using Type = half2;
};  // keep for generality

template <>
struct TypeConverter<half2> {
  using Type = c10::Half;
};

template <>
struct TypeConverter<c10::Half> {
  using Type = half2;
};

template <>
struct TypeConverter<__nv_bfloat162> {
  using Type = c10::BFloat16;
};

template <>
struct TypeConverter<c10::BFloat16> {
  using Type = __nv_bfloat162;
};

#define ELTS_PER_THREAD 8

constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
constexpr int CVT_FP4_SF_VEC_SIZE = 16;

// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  uint32_t val;
  asm volatile(
      "{\n"
      ".reg .b8 byte0;\n"
      ".reg .b8 byte1;\n"
      ".reg .b8 byte2;\n"
      ".reg .b8 byte3;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
      "}"
      : "=r"(val)
      : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
        "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
  return val;
#else
  return 0;
#endif
}

// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  uint32_t val;
  asm volatile(
      "{\n"
      ".reg .b8 byte0;\n"
      ".reg .b8 byte1;\n"
      ".reg .b8 byte2;\n"
      ".reg .b8 byte3;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
      "}"
      : "=r"(val)
      : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
        "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
  return val;
#else
  return 0;
#endif
}

// Fast reciprocal.
inline __device__ float reciprocal_approximate_ftz(float a) {
  float b;
  asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
  return b;
}
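Note: the e2m1 (FP4) format produced by `fp32_vec_to_e2m1` above can only represent the magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}. The following is a rough Python sketch of that conversion (nearest-value rounding only; the PTX `cvt.rn.satfinite.e2m1x2.f32` instruction additionally defines exact tie-breaking and saturation behavior in hardware, which this sketch does not reproduce):

```python
import torch

# Representable magnitudes of e2m1 (1 sign, 2 exponent, 1 mantissa bits).
E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def fp32_to_e2m1_code(x: torch.Tensor) -> torch.Tensor:
    """Map each float to the 4-bit code (0..15) of the nearest e2m1 value.

    Bit 3 is the sign, bits 0..2 select the magnitude. Magnitudes beyond 6.0
    saturate to the largest representable value ('satfinite' behavior).
    """
    sign = (x < 0).to(torch.uint8)
    mag = x.abs().clamp(max=6.0)
    # Nearest representable magnitude (ties are not broken exactly as the HW does).
    idx = (mag.unsqueeze(-1) - E2M1_VALUES).abs().argmin(dim=-1).to(torch.uint8)
    return (sign << 3) | idx


def pack_e2m1_pairs(codes: torch.Tensor) -> torch.Tensor:
    """Pack two 4-bit codes per byte, first element in the low nibble."""
    lo, hi = codes[..., 0::2], codes[..., 1::2]
    return lo | (hi << 4)


if __name__ == "__main__":
    vals = torch.tensor([0.24, -0.26, 1.1, -5.9, 7.0, 2.4, -2.6, 0.0])
    codes = fp32_to_e2m1_code(vals)
    print(codes, pack_e2m1_pairs(codes))
```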
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
                                                       int numCols,
                                                       SFType* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
                CVT_FP4_NUM_THREADS_PER_SF == 2);

  // One pair of threads writes one SF to global memory.
  // TODO: stage through smem for packed STG.32.
  // Is it better than STG.8 from 4 threads?
  if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
    // SF vector index (16 elements share one SF in the K dimension).
    int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
    int32_t mIdx = rowIdx;

    // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4 (kTile)]
    // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]

    int32_t mTileIdx = mIdx / (32 * 4);
    // SF vector size 16.
    int factor = CVT_FP4_SF_VEC_SIZE * 4;
    int32_t numKTiles = (numCols + factor - 1) / factor;
    int64_t mTileStride = numKTiles * 32 * 4 * 4;

    int32_t kTileIdx = (kIdx / 4);
    int64_t kTileStride = 32 * 4 * 4;

    // M tile layout [32, 4] is column-major.
    int32_t outerMIdx = (mIdx % 32);
    int64_t outerMStride = 4 * 4;

    int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
    int64_t innerMStride = 4;

    int32_t innerKIdx = (kIdx % 4);
    int64_t innerKStride = 1;

    // Compute the global offset.
    int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride +
                       outerMIdx * outerMStride + innerMIdx * innerMStride +
                       innerKIdx * innerKStride;

    return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
  }
#endif
  return nullptr;
}
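The scale factors are not stored row-major: `cvt_quant_to_fp4_get_sf_out_offset` writes them in the swizzled [numMTiles, numKTiles, 32, 4, 4] layout the NVFP4 GEMM expects. A small Python sketch of the same index arithmetic (assuming one scale per 16-element block, i.e. `CVT_FP4_SF_VEC_SIZE = 16`):

```python
def sf_out_offset(row_idx: int, k_block_idx: int, num_cols: int) -> int:
    """Byte offset of the scale factor for (row, 16-wide K block).

    Mirrors the kernel's [numMTiles, numKTiles, 32 (m), 4 (m), 4 (k)] layout,
    where an M tile covers 32 * 4 = 128 rows and a K tile covers 4 scale vectors.
    """
    sf_vec_size = 16                       # 16 elements share one scale
    num_k_tiles = (num_cols + sf_vec_size * 4 - 1) // (sf_vec_size * 4)

    m_tile_idx, k_tile_idx = row_idx // (32 * 4), k_block_idx // 4
    outer_m_idx = row_idx % 32             # column-major [32, 4] M tile
    inner_m_idx = (row_idx % (32 * 4)) // 32
    inner_k_idx = k_block_idx % 4

    return (m_tile_idx * num_k_tiles * 32 * 4 * 4 +
            k_tile_idx * 32 * 4 * 4 +
            outer_m_idx * 4 * 4 +
            inner_m_idx * 4 +
            inner_k_idx)


# Example: a 256 x 128 activation has 128 / 16 = 8 scale blocks per row.
assert sf_out_offset(0, 0, 128) == 0
assert sf_out_offset(1, 0, 128) == 16     # adjacent rows are 16 bytes apart
assert sf_out_offset(0, 4, 128) == 512    # the next K tile is 32*4*4 bytes away
```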
// Define a 16-byte packed data type.
template <class Type>
struct PackedVec {
  typename TypeConverter<Type>::Type elts[4];
};

template <>
struct PackedVec<__nv_fp8_e4m3> {
  __nv_fp8x2_e4m3 elts[8];
};

template <class Type>
__inline__ __device__ PackedVec<Type> compute_silu(PackedVec<Type>& vec,
                                                   PackedVec<Type>& vec2) {
  PackedVec<Type> result;
#pragma unroll
  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) {
    if constexpr (std::is_same_v<Type, c10::Half>) {
      half2 val(0.5f, 0.5f);
      half2 t0 = __hmul2(vec.elts[i], val);
      half2 t1 = __hfma2(h2tanh(t0), val, val);
      half2 t2 = __hmul2(vec.elts[i], t1);
      result.elts[i] = __hmul2(t2, vec2.elts[i]);
    } else {
      __nv_bfloat162 val(0.5f, 0.5f);
      __nv_bfloat162 t0 = __hmul2(vec.elts[i], val);
      __nv_bfloat162 t1 = __hfma2(h2tanh(t0), val, val);
      __nv_bfloat162 t2 = __hmul2(vec.elts[i], t1);
      result.elts[i] = __hmul2(t2, vec2.elts[i]);
    }
  }
  return result;
}
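`compute_silu` evaluates SiLU through the tanh identity silu(x) = x * sigmoid(x) = x * 0.5 * (1 + tanh(x / 2)), which maps onto one `h2tanh` plus packed multiply-adds, and then multiplies by the second (up-projection) half of the input. A quick Python check of that identity on a gate/up pair (a standalone verification, not part of the PR):

```python
import torch


def silu_mul_reference(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    """silu(gate) * up, written exactly the way the kernel does it (tanh form)."""
    t0 = gate * 0.5
    t1 = torch.tanh(t0) * 0.5 + 0.5        # sigmoid(gate) via tanh
    return gate * t1 * up


gate = torch.randn(4, 128, dtype=torch.float32)
up = torch.randn(4, 128, dtype=torch.float32)
expected = torch.nn.functional.silu(gate) * up
torch.testing.assert_close(silu_mul_reference(gate, up), expected,
                           atol=1e-5, rtol=1e-5)
```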
// Quantizes the provided PackedVec into the uint32_t output.
template <class Type, bool UE8M0_SF = false>
__device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec<Type>& vec,
                                                  PackedVec<Type>& vec2,
                                                  float SFScaleVal,
                                                  uint8_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  PackedVec<Type> out_silu = compute_silu(vec, vec2);
  // Get absolute maximum values among the local 8 values.
  auto localMax = __habs2(out_silu.elts[0]);

  // Local maximum value.
#pragma unroll
  for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
    localMax = __hmax2(localMax, __habs2(out_silu.elts[i]));
  }

  // Get the absolute maximum among all 16 values (two threads).
  localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
  // Get the final absolute maximum values.
  float vecMax = float(__hmax(localMax.x, localMax.y));

  // Get the SF (max value of the vector / max value of e2m1).
  // Maximum value of e2m1 = 6.0.
  // TODO: use half as compute data type.
  float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
  // 8-bit representation of the SF.
  uint8_t fp8SFVal;
  // Write the SF to global memory (STG.8).
  if constexpr (UE8M0_SF) {
    // Extract the 8 exponent bits from float32.
    // float 32 bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
    uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
    fp8SFVal = tmp & 0xff;
    // Convert back to fp32.
    reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
  } else {
    // Here SFValue is always positive, so E4M3 is the same as UE4M3.
    __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
    reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
    // Convert back to fp32.
    SFValue = float(tmp);
  }
  // Get the output scale.
  // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
  //                       reciprocal(SFScaleVal)
  float outputScale =
      SFValue != 0 ? reciprocal_approximate_ftz(
                         SFValue * reciprocal_approximate_ftz(SFScaleVal))
                   : 0.0f;

  if (SFout) {
    // Write the SF to global memory (STG.8).
    *SFout = fp8SFVal;
  }

  // Convert the input to float.
  float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];

#pragma unroll
  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
    if constexpr (std::is_same_v<Type, c10::Half>) {
      fp2Vals[i] = __half22float2(out_silu.elts[i]);
    } else {
      fp2Vals[i] = __bfloat1622float2(out_silu.elts[i]);
    }
    fp2Vals[i].x *= outputScale;
    fp2Vals[i].y *= outputScale;
  }

  // Convert to e2m1 values.
  uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);

  // Write the e2m1 values to global memory.
  return e2m1Vec;
#else
  return 0;
#endif
}
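Per 16-element block, the kernel derives the stored scale and the multiplier applied to the data as described in the recipe above. A hedged Python restatement (here `global_scale` plays the role of `SFScaleVal`, and the FP8 round-trip uses E4M3 as in the non-UE8M0 path; the `448.0 * 6.0 / amax` choice of global scale follows the unit test added later in this PR):

```python
import torch


def nvfp4_block_scale(block: torch.Tensor, global_scale: float):
    """Return (stored_fp8_scale, output_scale) for one 16-element block."""
    vec_max = block.abs().max().item()
    # SF = global_scale * (block_max / 6.0); 6.0 is the largest e2m1 magnitude.
    sf = global_scale * vec_max / 6.0
    # The scale is stored in FP8 (E4M3); its quantization error must be folded
    # into the multiplier applied to the data, so round-trip it first.
    sf_fp8 = torch.tensor(sf).to(torch.float8_e4m3fn)
    sf_rt = float(sf_fp8.to(torch.float32))
    # outputScale = 1 / (SF * (1 / global_scale)) = global_scale / SF.
    output_scale = global_scale / sf_rt if sf_rt != 0 else 0.0
    return sf_fp8, output_scale


block = torch.randn(16)
sf_fp8, out_scale = nvfp4_block_scale(block, global_scale=448.0 * 6.0 / 3.0)
# Elements are multiplied by out_scale before the e2m1 conversion, so a later
# dequantization of (e2m1 * sf / global_scale) roughly recovers the original.
```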
// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(1024, 4) silu_and_cvt_fp16_to_fp4(
#else
silu_and_cvt_fp16_to_fp4(
#endif
    int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
    uint32_t* out, uint32_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  using PackedVec = PackedVec<Type>;
  static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
      (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
                "Vec size is not matched.");

  // Get the global scaling factor, which will be applied to the SF.
  // Note SFScale is the same as next GEMM's alpha, which is
  // (448.f / (Alpha_A / 6.f)).
  float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];

  // Input tensor row/col loops.
  for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
    for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD;
         colIdx += blockDim.x) {
      int64_t inOffset =
          rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) + colIdx;
      int64_t inOffset2 = rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) +
                          numCols / CVT_FP4_ELTS_PER_THREAD + colIdx;
      PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
      PackedVec in_vec2 = reinterpret_cast<PackedVec const*>(in)[inOffset2];

      // Get the output tensor offset.
      // Same as inOffset because 8 elements are packed into one uint32_t.
      int64_t outOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
      auto& out_pos = out[outOffset];

      auto sf_out =
          cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
                                             CVT_FP4_NUM_THREADS_PER_SF>(
              rowIdx, colIdx, numCols, SFout);

      out_pos = silu_and_cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(
          in_vec, in_vec2, SFScaleVal, sf_out);
    }
  }
#endif
}

}  // namespace vllm

void silu_and_mul_nvfp4_quant(torch::Tensor& output,  // [..., d]
                              torch::Tensor& output_sf,
                              torch::Tensor& input,  // [..., 2 * d]
                              torch::Tensor& input_sf) {
  TORCH_CHECK(input.dtype() == torch::kFloat16 ||
              input.dtype() == torch::kBFloat16);
  int32_t m = input.size(0);
  int32_t n = input.size(1) / 2;
  TORCH_CHECK(n % 16 == 0, "The N dimension must be a multiple of 16.");
  int multiProcessorCount =
      get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
  auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
  auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
  auto output_ptr = static_cast<int64_t*>(output.data_ptr());
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
  dim3 block(std::min(int(n / ELTS_PER_THREAD), 1024));
  int const numBlocksPerSM = 2048 / block.x;
  dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
  VLLM_DISPATCH_HALF_TYPES(
      input.scalar_type(), "act_and_mul_quant_kernel", [&] {
        auto input_ptr = reinterpret_cast<scalar_t const*>(input.data_ptr());
        VLLM_DISPATCH_BYTE_TYPES(
            output.scalar_type(), "fused_act_and_mul_quant_kernel_nvfp4_type",
            [&] {
              vllm::silu_and_cvt_fp16_to_fp4<scalar_t>
                  <<<grid, block, 0, stream>>>(
                      m, n, input_ptr, input_sf_ptr,
                      reinterpret_cast<uint32_t*>(output_ptr),
                      reinterpret_cast<uint32_t*>(sf_out));
            });
      });
}
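Shape-wise, the host wrapper takes the un-activated gate/up tensor of shape [m, 2*n] in FP16/BF16 and writes m*n FP4 values packed two per byte, plus the swizzled block scales. A hedged usage sketch (requires a CUDA build of vLLM with this op; tensor names and the global-scale value are illustrative, and the scale-buffer sizing follows the unit test added below):

```python
import torch


def round_up(v: int, m: int) -> int:
    return (v + m - 1) // m * m


m, n = 256, 4096                    # n = d, the width after SiluAndMul
x = torch.randn(m, 2 * n, dtype=torch.float16, device="cuda")
# Global scale, typically 448 * 6 / amax of the activation (see the test below).
input_global_scale = torch.tensor([448.0 * 6.0 / 3.5], device="cuda")

out = torch.empty(m, n // 2, dtype=torch.uint8, device="cuda")    # packed e2m1
block_scale = torch.empty(round_up(m, 128), round_up(n // 16, 4) // 4,
                          dtype=torch.int32, device="cuda")       # 4 FP8 SFs per int32

# In-place custom op registered by this PR: writes `out` and `block_scale`.
torch.ops._C.silu_and_mul_nvfp4_quant(out, block_scale, x, input_global_scale)
```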
@@ -115,6 +115,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
      "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
  ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);

#ifndef USE_ROCM
  ops.def(
      "silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
      "Tensor input, Tensor input_global_scale) -> ()");
  ops.impl("silu_and_mul_nvfp4_quant", torch::kCUDA, &silu_and_mul_nvfp4_quant);
#endif

  ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
  ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);
@@ -4,32 +4,41 @@ import pytest
import torch

import vllm.envs as envs
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
# yapf conflicts with isort for this block
# yapf: disable
from vllm.compilation.activation_quant_fusion import (
    FUSED_OPS, SILU_MUL_OP, ActivationQuantFusionPass)
# yapf: enable
from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
    GroupShape, kFp8StaticTensorSym, kNvfp4Quant)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    Fp8LinearOp)
from vllm.platforms import current_platform

from .backend import TestBackend

FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8


class TestModel(torch.nn.Module):

    def __init__(self, hidden_size: int, force_fp8_e4m3fnuz: bool, *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
def is_nvfp4_supported():
    return current_platform.has_device_capability(100)


class TestSiluMulFp8QuantModel(torch.nn.Module):

    def __init__(self, hidden_size: int, force_fp8_e4m3fnuz: bool, **kwargs):
        super().__init__()
        self.silu_and_mul = SiluAndMul()
        self.wscale = torch.rand(1, dtype=torch.float32)
        self.scale = torch.rand(1, dtype=torch.float32)

        self.w = (torch.rand(
            hidden_size,
            hidden_size).to(dtype=current_platform.fp8_dtype()).t())
        self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()

        self.fp8_linear = Fp8LinearOp(
            force_fp8_e4m3fnuz=force_fp8_e4m3fnuz,
@@ -45,14 +54,56 @@ class TestModel(torch.nn.Module):
                                  input_scale=self.wscale)
        return x2

    def ops_in_model_before(self):
        return [SILU_MUL_OP, QUANT_OPS[kFp8StaticTensorSym]]

@pytest.mark.parametrize("num_tokens", [256])
@pytest.mark.parametrize("hidden_size", [64])
    def ops_in_model_after(self):
        return [FUSED_OPS[kFp8StaticTensorSym]]


class TestSiluMulNvfp4QuantModel(torch.nn.Module):

    def __init__(self, hidden_size: int, **kwargs):
        super().__init__()
        self.silu_and_mul = SiluAndMul()
        self.w = torch.randint(256, (hidden_size, hidden_size // 2),
                               dtype=FP4_DTYPE)
        self.wscale = torch.randn(hidden_size,
                                  hidden_size // 16).to(dtype=FP8_DTYPE)
        self.wscale2 = torch.rand(1, dtype=torch.float32)
        self.scale = torch.rand(1, dtype=torch.float32)

    def forward(self, x):
        y = self.silu_and_mul(x)
        y_quant, y_block_scale = scaled_fp4_quant(y, 1 / self.scale)
        out = cutlass_scaled_fp4_mm(a=y_quant,
                                    b=self.w,
                                    block_scale_a=y_block_scale,
                                    block_scale_b=self.wscale,
                                    alpha=self.scale * self.wscale2,
                                    out_dtype=y.dtype)
        return out

    def ops_in_model_before(self):
        return [SILU_MUL_OP, QUANT_OPS[kNvfp4Quant]]

    def ops_in_model_after(self):
        return [FUSED_OPS[kNvfp4Quant]]


@pytest.mark.parametrize("num_tokens", [64])
@pytest.mark.parametrize("hidden_size", [128])
@pytest.mark.parametrize(
    "model_class", [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
    if is_nvfp4_supported() else [TestSiluMulFp8QuantModel])
@pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
                    reason="Only test on CUDA and ROCm")
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size,
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
                                   force_fp8_e4m3fnuz):
    if model_class == TestSiluMulNvfp4QuantModel and force_fp8_e4m3fnuz:
        pytest.skip("Duplicate tests for NVFP4")

    torch.set_default_device("cuda")
    torch.set_default_dtype(torch.float16)

@@ -63,7 +114,8 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size,
    fusion_pass = ActivationQuantFusionPass(config)

    backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
    model = TestModel(hidden_size, force_fp8_e4m3fnuz)
    model = model_class(hidden_size=hidden_size,
                        force_fp8_e4m3fnuz=force_fp8_e4m3fnuz)

    # First dimension dynamic
    x = torch.rand(num_tokens, hidden_size * 2)
@@ -80,17 +132,8 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size,
                               atol=1e-3,
                               rtol=1e-3)

    # Check substitution worked
    pre_nodes = backend.graph_pre_pass.nodes
    post_nodes = backend.graph_post_pass.nodes
    # In pre-nodes, quant op should be present and fused kernels should not
    backend.check_before_ops(model.ops_in_model_before())

    silu_and_mul_quant = torch.ops._C.silu_and_mul_quant.default
    fp8_quant = torch.ops._C.static_scaled_fp8_quant.default

    # In pre-nodes, fp8 quant should be present and fused kernels should not
    assert find_auto_fn_maybe(pre_nodes, silu_and_mul_quant) is None
    find_auto_fn(pre_nodes, fp8_quant)

    # In post-nodes, fused kernels should be present and fp8 quant should not
    find_auto_fn(post_nodes, silu_and_mul_quant)
    assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
    # In post-nodes, fused kernels should be present and quant op should not
    backend.check_after_ops(model.ops_in_model_after())
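The refactored test delegates the graph checks to each model's `ops_in_model_before()` / `ops_in_model_after()`: before the pass the traced graph must contain the standalone `silu_and_mul` plus the quant op for the scheme, afterwards only the fused op. Conceptually the check amounts to counting auto-functionalized call sites, something like the following simplified sketch (not the actual `TestBackend` helpers):

```python
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized


def count_op(graph: torch.fx.Graph, op) -> int:
    """Count auto_functionalized call sites of a given custom op in an FX graph."""
    return sum(1 for node in graph.nodes
               if node.op == "call_function"
               and node.target is auto_functionalized
               and node.args and node.args[0] is op)


# For the NVFP4 model one would expect, e.g.:
#   count_op(graph_before, torch.ops._C.silu_and_mul.default) > 0
#   count_op(graph_after, torch.ops._C.silu_and_mul_nvfp4_quant.default) > 0
#   count_op(graph_after, torch.ops._C.silu_and_mul.default) == 0
```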
tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py (new file, 126 lines)
@@ -0,0 +1,126 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

from tests.kernels.utils import opcheck
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

if not current_platform.has_device_capability(100):
    pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
                allow_module_level=True)

DTYPES = [torch.float16, torch.bfloat16]
SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)]
SEEDS = [42]
CUDA_DEVICES = ['cuda:0']

FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

BLOCK_SIZE = 16


def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor,
             global_scale: torch.Tensor,
             ref_output_scale: torch.Tensor) -> torch.Tensor:
    silu_and_mul_out = silu_and_mul.forward_native(x)
    assert not current_platform.is_rocm()
    assert silu_and_mul_out.ndim >= 1, (
        f'input.ndim needs to be >= 1, but got {silu_and_mul_out.ndim}.')
    other_dims = 1 if silu_and_mul_out.ndim == 1 else -1
    silu_and_mul_out = silu_and_mul_out.reshape(other_dims,
                                                silu_and_mul_out.shape[-1])
    m, n = silu_and_mul_out.shape
    device = silu_and_mul_out.device

    # Two fp4 values will be packed into an uint8.
    out = torch.empty((m, n // 2), device=device, dtype=torch.uint8)

    output_scale = ref_output_scale

    torch.ops._C.scaled_fp4_quant(out, silu_and_mul_out, output_scale,
                                  global_scale)

    return out, output_scale


def ops_impl(x: torch.Tensor, global_scale: torch.Tensor,
             ref_output_scale: torch.Tensor) -> torch.Tensor:
    out_shape = (x.shape[0], x.shape[1] // 4)
    output_scale = ref_output_scale
    out = torch.empty(out_shape, dtype=torch.uint8, device=x.device)
    torch.ops._C.silu_and_mul_nvfp4_quant(out, output_scale, x, global_scale)
    return out, output_scale


@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_quantize_to_fp4(
    dtype: torch.dtype,
    shape: tuple[int, int],
    seed: int,
    device: str,
) -> None:
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    m, n = shape

    x = torch.randn((m, n), dtype=dtype)
    tensor_amax = torch.abs(x).max().to(torch.float32)
    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax

    block_size = 16

    assert n % block_size == 0, (
        f'last dim has to be multiple of 16, but got {n}.')
    assert x.dtype in (torch.float16, torch.bfloat16), (
        f'input.dtype needs to be fp16 or bf16 but got {x.dtype}.')

    round_up = lambda x, y: (x + y - 1) // y * y
    rounded_m = round_up(x.shape[0], 128)
    scale_n = x.shape[1] // (2 * block_size)
    rounded_n = round_up(scale_n, 4)
    output_scale = torch.empty((rounded_m, rounded_n // 4),
                               device=x.device,
                               dtype=torch.int32)

    layer = SiluAndMul()

    ref_out, ref_out_scale = ref_impl(layer, x, global_scale, output_scale)

    fusion_out, fusion_out_scale = ops_impl(x, global_scale, output_scale)

    assert ref_out.dtype == torch.uint8
    assert fusion_out.dtype == torch.uint8
    assert ref_out.shape == fusion_out.shape

    assert ref_out_scale.dtype == torch.int32
    assert fusion_out_scale.dtype == torch.int32
    assert ref_out_scale.shape == fusion_out_scale.shape

    # Allow up to 2% of mismatched values since BF16 has accuracy issues.
    mis_threshold = 0.02
    atol = 0.4
    rtol = 0.4
    ref_logits = ref_out[-1]
    fusion_logits = fusion_out[-1]

    mis_count = torch.sum(
        torch.abs(fusion_logits - ref_logits) > (atol +
                                                 rtol * torch.abs(ref_logits)))
    mis_ratio = mis_count / fusion_logits.numel()

    assert mis_ratio < mis_threshold, \
        f"Mismatch ratio {mis_ratio} exceeds threshold {mis_threshold}"

    torch.testing.assert_close(ref_out_scale, fusion_out_scale)

    opcheck(torch.ops._C.silu_and_mul_nvfp4_quant,
            (fusion_out, fusion_out_scale, x, global_scale))
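The block-scale buffer sizing in the test follows the swizzled layout: one FP8 scale per 16 output elements, rows padded to a multiple of 128, scale columns padded to a multiple of 4, and four FP8 scales packed into each `int32`. A worked check for the smallest test shape:

```python
# shape (m, n) = (128, 64): the input has 64 columns, so SiluAndMul produces
# 64 / 2 = 32 output columns and 32 / 16 = 2 block scales per row.
m, n = 128, 64
block_size = 16

round_up = lambda v, mult: (v + mult - 1) // mult * mult
rounded_m = round_up(m, 128)            # rows padded to 128
scale_n = n // (2 * block_size)         # 2 scales per row
rounded_n = round_up(scale_n, 4)        # padded to 4
packed_cols = rounded_n // 4            # 4 FP8 scales per int32 -> 1 column

assert (rounded_m, packed_cols) == (128, 1)
```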
@@ -1,55 +1,154 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod

import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
                                             register_replacement)
from torch._ops import OpOverload

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    QuantKey, kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale)
from vllm.platforms import current_platform

from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass

logger = init_logger(__name__)

FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8

def silu_mul_pattern_static(result: torch.Tensor,
                            result_silu_mul: torch.Tensor, input: torch.Tensor,
                            scale: torch.Tensor):
    at1 = auto_functionalized(torch.ops._C.silu_and_mul.default,
SILU_MUL_OP = torch.ops._C.silu_and_mul.default

FUSED_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: torch.ops._C.silu_and_mul_quant.default,  # noqa: E501
}
if current_platform.is_cuda() and hasattr(torch.ops._C,
                                          "silu_and_mul_nvfp4_quant"):
    FUSED_OPS[
        kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default  # noqa: E501


class ActivationQuantPattern(ABC):
    """
    The base class for Activation+Quant fusions.
    Should not be used directly.
    """

    def __init__(
        self,
        quant_key: QuantKey,
    ):
        self.quant_key = quant_key
        self.quant_dtype = quant_key.dtype

        assert self.quant_key in QUANT_OPS, \
            f"unsupported quantization scheme {self.quant_key}"
        self.QUANT_OP = QUANT_OPS[self.quant_key]

        assert self.quant_key in FUSED_OPS, \
            f"unsupported fusion scheme {self.quant_key}"
        self.FUSED_OP = FUSED_OPS[self.quant_key]

    def empty_quant(self, *args, **kwargs):
        kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
        return torch.empty(*args, **kwargs)

    @abstractmethod
    def register(self, pm_pass: PatternMatcherPass):
        raise NotImplementedError


class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
    """
    Fusion for SiluMul+Fp8StaticQuant Pattern
    """

    def __init__(self, symmetric: bool = True):
        quant_key = QuantKey(dtype=FP8_DTYPE,
                             scale=kStaticTensorScale,
                             symmetric=symmetric)
        super().__init__(quant_key)

    def register(self, pm_pass: PatternMatcherPass):

        def pattern(result: torch.Tensor, result_silu_mul: torch.Tensor,
                    input: torch.Tensor, scale: torch.Tensor):
            at1 = auto_functionalized(SILU_MUL_OP,
                                      result=result_silu_mul,
                                      input=input)
            at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default,
            at2 = auto_functionalized(self.QUANT_OP,
                                      result=result,
                                      input=at1[1],
                                      scale=scale)
            return at2[1]


def silu_mul_replacement_static(result: torch.Tensor,
                                result_silu_mul: torch.Tensor,
        def replacement(result: torch.Tensor, result_silu_mul: torch.Tensor,
                        input: torch.Tensor, scale: torch.Tensor):
            at = auto_functionalized(torch.ops._C.silu_and_mul_quant.default,
            at = auto_functionalized(self.FUSED_OP,
                                     result=result,
                                     input=input,
                                     scale=scale)
            return at[1]

        inputs = [
            self.empty_quant(5, 4),  # result
            empty_bf16(5, 4),  # result_silu_mul
            empty_bf16(5, 4),  # input
            empty_fp32(1, 1)  # scale
        ]

def empty_bf16(*args, **kwargs):
    return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
        register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)


def empty_fp8(*args, **kwargs):
    fp8 = current_platform.fp8_dtype()
    return torch.empty(*args, **kwargs, dtype=fp8, device="cuda")
class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
    """
    Fusion for SiluMul+Nvfp4Quant Pattern
    """

    def __init__(self):
        super().__init__(kNvfp4Quant)


def empty_fp32(*args, **kwargs):
    return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
    def register(self, pm_pass: PatternMatcherPass):

        def pattern(result: torch.Tensor, output_scale: torch.Tensor,
                    result_silu_mul: torch.Tensor, input: torch.Tensor,
                    scale: torch.Tensor):
            at1 = auto_functionalized(SILU_MUL_OP,
                                      result=result_silu_mul,
                                      input=input)
            at2 = auto_functionalized(self.QUANT_OP,
                                      output=result,
                                      input=at1[1],
                                      output_scale=output_scale,
                                      input_scale=scale)
            return at2[1], at2[2]

        def replacement(result: torch.Tensor, output_scale: torch.Tensor,
                        result_silu_mul: torch.Tensor, input: torch.Tensor,
                        scale: torch.Tensor):
            at = auto_functionalized(self.FUSED_OP,
                                     result=result,
                                     result_block_scale=output_scale,
                                     input=input,
                                     input_global_scale=scale)
            return at[1], at[2]

        inputs = [
            self.empty_quant(5, 32),  # result
            empty_i32(128, 4),  # output_scale
            empty_bf16(5, 64),  # result_silu_mul
            empty_bf16(5, 64),  # input
            empty_fp32(1, 1)  # scale
        ]

        register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)


class ActivationQuantFusionPass(VllmInductorPass):
@@ -69,15 +168,11 @@ class ActivationQuantFusionPass(VllmInductorPass):
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="activation_quant_fusion_pass")

        inputs = [
            empty_fp8(5, 4),  # Quant output
            empty_bf16(5, 4),  # Silu_and_mul output
            empty_bf16(5, 4),  # Input
            empty_fp32(1, 1)  # Scale
        ]
        register_replacement(silu_mul_pattern_static,
                             silu_mul_replacement_static, inputs, fwd_only,
                             self.patterns)
        pattern_silu_mul_fp8 = SiluMulFp8StaticQuantPattern()
        pattern_silu_mul_fp8.register(self.patterns)

        pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
        pattern_silu_mul_nvfp4.register(self.patterns)

    def __call__(self, graph: torch.fx.Graph):
        self.begin()
@@ -89,3 +184,8 @@ class ActivationQuantFusionPass(VllmInductorPass):

        self.dump_graph(graph, "after_act_quant_fusion")
        self.end_and_log()

    def uuid(self):
        return VllmInductorPass.hash_source(self, ActivationQuantPattern,
                                            SiluMulFp8StaticQuantPattern,
                                            SiluMulNvfp4QuantPattern)
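Each `ActivationQuantPattern` subclass hands the Inductor pattern matcher a traced "before" subgraph and its fused replacement via `register_replacement`; the example `inputs` exist only so both functions can be traced into FX graphs up front. A toy, self-contained illustration of that mechanism (unrelated to vLLM's ops, using an algebraically equivalent pattern/replacement pair):

```python
import torch
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
                                             register_replacement)


# Pattern and replacement are traced with the example inputs; the matcher then
# rewrites matching subgraphs inside the compiled FX graph.
def pattern(x: torch.Tensor):
    return torch.relu(x) * 2.0


def replacement(x: torch.Tensor):
    # Algebraically identical; stands in for a fused custom op.
    return torch.relu(x + x)


patterns = PatternMatcherPass(pass_name="toy_activation_pass")
register_replacement(pattern, replacement, [torch.empty(5, 4)], fwd_only,
                     patterns)
```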
@@ -97,6 +97,13 @@ class FixFunctionalizationPass(VllmInductorPass):
                    node,
                    mutated_args,
                    args=('result', 'input', 'scale'))
            elif at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default:
                mutated_args = {1: 'result', 2: 'result_block_scale'}
                self.defunctionalize(graph,
                                     node,
                                     mutated_args,
                                     args=('result', 'result_block_scale',
                                           'input', 'input_global_scale'))
            else:
                continue  # skip the count
@@ -885,6 +885,10 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
        layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2,
                                requires_grad=False)

        # Calculate `1 / input_scale` so that we don't need to do so at runtime
        layer.input_scale_inv = Parameter(
            (1 / layer.input_scale).to(torch.float32), requires_grad=False)

        # Swizzle the weight blockscale.
        # contracting dimension is input dimension
        # block_size = 16;
@@ -941,8 +945,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
        output_shape = [x.shape[0], layer.weight.shape[0]]

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        s_quant = 1 / layer.input_scale
        x_fp4, x_blockscale = scaled_fp4_quant(x, s_quant)
        x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv)

        # validate dtypes of quantized input, input block scale,
        # weight and weight_blockscale
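The new `input_scale_inv` parameter simply hoists the `1 / input_scale` division out of the forward pass, so at runtime the quant op receives a ready-made constant tensor. A minimal sketch of the idea on a generic, illustrative module (not vLLM's actual class):

```python
import torch


class QuantLinearSketch(torch.nn.Module):
    """Illustrative only: precompute 1/scale once at load time."""

    def __init__(self, input_scale: torch.Tensor):
        super().__init__()
        self.register_buffer("input_scale", input_scale.float())
        # Computed once here instead of on every forward call.
        self.register_buffer("input_scale_inv", (1.0 / input_scale).float())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Previously: s_quant = 1 / self.input_scale  (recomputed per call).
        # Now the reciprocal is already a constant buffer on the module.
        return x * self.input_scale_inv
```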