mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Feature] Non-contiguous Support for FP8 Quantization (#21961)
Signed-off-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@ -1,7 +1,8 @@
|
||||
#include "common.cuh"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include "../vectorization_utils.cuh"
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/cub.cuh>
|
||||
@ -12,74 +13,127 @@
|
||||
namespace vllm {
|
||||
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
__global__ void scaled_fp8_quant_kernel(fp8_type* __restrict__ out,
|
||||
const scalar_t* __restrict__ input,
|
||||
const float* __restrict__ scale,
|
||||
int64_t num_elems) {
|
||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
__global__ void scaled_fp8_quant_kernel_strided(
|
||||
fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
|
||||
const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
|
||||
int64_t out_row_stride) {
|
||||
const int64_t token_idx = blockIdx.x; // one token per block
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
// Invert the scale so that we can use multiplications to avoid expensive
|
||||
// division.
|
||||
const float inverted_scale = 1.0f / (*scale);
|
||||
scaled_fp8_conversion_vec<scalar_t, true>(
|
||||
out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
|
||||
const scalar_t* token_in = input + token_idx * in_row_stride;
|
||||
fp8_type* token_out = out + token_idx * out_row_stride;
|
||||
|
||||
const float inv_scale = 1.0f / (*scale);
|
||||
|
||||
vectorize_with_alignment<16>(
|
||||
token_in, token_out, hidden_size, tid, blockDim.x,
|
||||
[=] __device__(fp8_type & dst, const scalar_t& src) {
|
||||
dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src),
|
||||
inv_scale);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
|
||||
fp8_type* __restrict__ out, float* __restrict__ scale,
|
||||
scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
|
||||
const int hidden_size) {
|
||||
int const tid = threadIdx.x;
|
||||
int const token_idx = blockIdx.x;
|
||||
__global__ void segmented_max_reduction_strided(
|
||||
float* __restrict__ scale, const scalar_t* __restrict__ input,
|
||||
int hidden_size, int64_t in_row_stride, int64_t num_tokens) {
|
||||
__shared__ float cache[256];
|
||||
const int tid = threadIdx.x;
|
||||
int64_t token_idx = blockIdx.x;
|
||||
|
||||
// Use int64 to avoid overflowing an int32 when calculating this offset
|
||||
int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
|
||||
scalar_t const* __restrict__ token_input = &input[offset];
|
||||
fp8_type* __restrict__ token_output = &out[offset];
|
||||
|
||||
// For vectorization, token_input and token_output pointers need to be
|
||||
// aligned at 32-byte and 16-byte addresses respectively.
|
||||
bool const can_vectorize = hidden_size % 16 == 0;
|
||||
|
||||
float absmax_val = 0.0f;
|
||||
if (can_vectorize) {
|
||||
absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x);
|
||||
} else {
|
||||
for (int i = tid; i < hidden_size; i += blockDim.x) {
|
||||
float const x = static_cast<float>(token_input[i]);
|
||||
absmax_val = fmaxf(absmax_val, fabsf(x));
|
||||
}
|
||||
// one block per token. Guard in case gridDim.x > num_tokens.
|
||||
if (token_idx >= num_tokens) {
|
||||
return;
|
||||
}
|
||||
|
||||
const scalar_t* row_ptr = input + token_idx * in_row_stride;
|
||||
|
||||
// each thread scans elements of the row in a strided fashion.
|
||||
float thread_max = 0.0f;
|
||||
for (int e = tid; e < hidden_size; e += blockDim.x) {
|
||||
float v = fabsf(static_cast<float>(row_ptr[e]));
|
||||
thread_max = fmaxf(thread_max, v);
|
||||
}
|
||||
|
||||
cache[tid] = thread_max;
|
||||
__syncthreads();
|
||||
|
||||
// parallel reduction to find row max.
|
||||
for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) {
|
||||
if (tid < offset) {
|
||||
cache[tid] = fmaxf(cache[tid], cache[tid + offset]);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// thread 0 updates global scale (per-tensor) atomically.
|
||||
if (tid == 0) {
|
||||
atomicMaxFloat(scale, cache[0] / quant_type_max_v<fp8_type>);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
__global__ void scaled_fp8_quant_kernel_strided_dynamic(
|
||||
fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
|
||||
const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
|
||||
int64_t out_row_stride) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
const scalar_t* token_in = input + token_idx * in_row_stride;
|
||||
fp8_type* token_out = out + token_idx * out_row_stride;
|
||||
|
||||
const float reciprocal_scale = 1.0f / (*scale);
|
||||
vectorize_with_alignment<16>(
|
||||
token_in, token_out, hidden_size, tid, blockDim.x,
|
||||
[=] __device__(fp8_type & dst, const scalar_t& src) {
|
||||
dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src),
|
||||
reciprocal_scale);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
__global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided(
|
||||
fp8_type* __restrict__ out, float* __restrict__ scale,
|
||||
const scalar_t* __restrict__ input, const float* __restrict__ scale_ub,
|
||||
int hidden_size, int64_t in_row_stride, int64_t out_row_stride) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
// Use int64 to avoid overflowing an int32 when calculating this offset
|
||||
int64_t in_offset = static_cast<int64_t>(token_idx) * in_row_stride;
|
||||
int64_t out_offset = static_cast<int64_t>(token_idx) * out_row_stride;
|
||||
const scalar_t* token_in = input + in_offset;
|
||||
fp8_type* token_out = out + out_offset;
|
||||
|
||||
// 1) per-token absmax
|
||||
float absmax_val = 0.f;
|
||||
vectorize_read_with_alignment<16>(
|
||||
token_in, hidden_size, tid, blockDim.x, [&] __device__(scalar_t v) {
|
||||
absmax_val = fmaxf(absmax_val, fabsf(static_cast<float>(v)));
|
||||
});
|
||||
|
||||
using BlockReduce = cub::BlockReduce<float, 256>;
|
||||
__shared__ typename BlockReduce::TempStorage reduceStorage;
|
||||
float const block_absmax_val_maybe =
|
||||
BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
|
||||
__shared__ typename BlockReduce::TempStorage tmp;
|
||||
const float block_max =
|
||||
BlockReduce(tmp).Reduce(absmax_val, cub::Max{}, blockDim.x);
|
||||
|
||||
__shared__ float token_scale;
|
||||
if (tid == 0) {
|
||||
if (scale_ub) {
|
||||
token_scale = fminf(block_absmax_val_maybe, *scale_ub);
|
||||
} else {
|
||||
token_scale = block_absmax_val_maybe;
|
||||
}
|
||||
// token scale computation
|
||||
token_scale = scale_ub ? fminf(block_max, *scale_ub) : block_max;
|
||||
token_scale = fmaxf(token_scale / quant_type_max_v<fp8_type>,
|
||||
min_scaling_factor<fp8_type>::val());
|
||||
scale[token_idx] = token_scale;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Note that we don't use inverted scales so we can match FBGemm impl.
|
||||
if (can_vectorize) {
|
||||
scaled_fp8_conversion_vec<scalar_t, false>(
|
||||
token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
|
||||
} else {
|
||||
for (int i = tid; i < hidden_size; i += blockDim.x) {
|
||||
token_output[i] = scaled_fp8_conversion<false, fp8_type>(
|
||||
static_cast<float>(token_input[i]), token_scale);
|
||||
}
|
||||
}
|
||||
// 2) quantize
|
||||
vectorize_with_alignment<16>(
|
||||
token_in, token_out, hidden_size, tid, blockDim.x,
|
||||
[=] __device__(fp8_type & dst, const scalar_t& src) {
|
||||
dst = scaled_fp8_conversion<false, fp8_type>(static_cast<float>(src),
|
||||
token_scale);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -88,23 +142,31 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor const& scale) // [1]
|
||||
{
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
int const block_size = 256;
|
||||
int const num_tokens = input.numel() / input.size(-1);
|
||||
int const num_elems = input.numel();
|
||||
dim3 const grid(num_tokens);
|
||||
dim3 const block(block_size);
|
||||
TORCH_CHECK(input.stride(-1) == 1,
|
||||
"last dimension of input must be contiguous");
|
||||
TORCH_CHECK(out.stride(-1) == 1,
|
||||
"last dimension of output must be contiguous");
|
||||
|
||||
const int hidden_size = input.size(-1);
|
||||
const int num_tokens = input.numel() / hidden_size;
|
||||
const int block_size = 256;
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(block_size);
|
||||
|
||||
const int64_t in_row_stride = input.stride(-2);
|
||||
const int64_t out_row_stride = out.stride(-2);
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
|
||||
VLLM_DISPATCH_FP8_TYPES(
|
||||
out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
|
||||
vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
|
||||
vllm::scaled_fp8_quant_kernel_strided<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(), num_elems);
|
||||
scale.data_ptr<float>(), hidden_size, in_row_stride,
|
||||
out_row_stride);
|
||||
});
|
||||
});
|
||||
}
|
||||
@ -113,27 +175,42 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor& scale) // [1]
|
||||
{
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
int const block_size = 256;
|
||||
int const num_tokens = input.numel() / input.size(-1);
|
||||
int const num_elems = input.numel();
|
||||
dim3 const grid(num_tokens);
|
||||
dim3 const block(block_size);
|
||||
TORCH_CHECK(input.stride(-1) == 1,
|
||||
"last dimension of input must be contiguous");
|
||||
TORCH_CHECK(out.stride(-1) == 1,
|
||||
"last dimension of output must be contiguous");
|
||||
|
||||
const int hidden_size = input.size(-1);
|
||||
const int num_tokens = input.numel() / hidden_size;
|
||||
const int block_size = 256;
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(block_size);
|
||||
|
||||
const int64_t in_row_stride = input.stride(-2);
|
||||
const int64_t out_row_stride = out.stride(-2);
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// scale tensor should be initialised to <=0 before reduction
|
||||
AT_CUDA_CHECK(
|
||||
cudaMemsetAsync(scale.data_ptr<float>(), 0, sizeof(float), stream));
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
|
||||
VLLM_DISPATCH_FP8_TYPES(
|
||||
out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
|
||||
vllm::segmented_max_reduction<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(scale.data_ptr<float>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
num_elems);
|
||||
vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
|
||||
vllm::segmented_max_reduction_strided<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
scale.data_ptr<float>(), input.data_ptr<scalar_t>(),
|
||||
hidden_size, in_row_stride,
|
||||
static_cast<int64_t>(num_tokens));
|
||||
|
||||
vllm::scaled_fp8_quant_kernel_strided_dynamic<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(), num_elems);
|
||||
scale.data_ptr<float>(), hidden_size, in_row_stride,
|
||||
out_row_stride);
|
||||
});
|
||||
});
|
||||
}
|
||||
@ -142,14 +219,19 @@ void dynamic_per_token_scaled_fp8_quant(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK(input.stride(-1) == 1,
|
||||
"last dimension of input must be contiguous");
|
||||
TORCH_CHECK(out.stride(-1) == 1,
|
||||
"last dimension of output must be contiguous");
|
||||
|
||||
int const hidden_size = input.size(-1);
|
||||
int const num_tokens = input.numel() / hidden_size;
|
||||
int const block_size = 256;
|
||||
dim3 const grid(num_tokens);
|
||||
dim3 const block(std::min(hidden_size, block_size));
|
||||
const int hidden_size = input.size(-1);
|
||||
const int num_tokens = input.numel() / hidden_size;
|
||||
const int block_size = 256;
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(hidden_size, block_size));
|
||||
|
||||
const int64_t in_row_stride = input.stride(-2);
|
||||
const int64_t out_row_stride = out.stride(-2);
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
@ -159,13 +241,12 @@ void dynamic_per_token_scaled_fp8_quant(
|
||||
VLLM_DISPATCH_FP8_TYPES(
|
||||
out.scalar_type(),
|
||||
"dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
|
||||
vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
scale_ub.has_value() ? scale_ub->data_ptr<float>()
|
||||
: nullptr,
|
||||
hidden_size);
|
||||
vllm::dynamic_per_token_scaled_fp8_quant_kernel_strided<
|
||||
scalar_t, fp8_t><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
|
||||
hidden_size, in_row_stride, out_row_stride);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
@ -55,111 +55,4 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
|
||||
#endif
|
||||
}
|
||||
|
||||
// Compute the absolute maximum m of the input tensor and store
|
||||
// m / float8_e4m3::max() in *scale. Each thread block performs a
|
||||
// reduction tree and the memory in scale is atomically updated.
|
||||
// So to get the right answer, *scale needs to be initialized to
|
||||
// a value <= 0.0 and we need to wait for all thread blocks to
|
||||
// finish before consuming *scale.
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
__global__ void segmented_max_reduction(float* __restrict__ scale,
|
||||
const scalar_t* __restrict__ input,
|
||||
int64_t num_elems) {
|
||||
__shared__ float cache[256];
|
||||
int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
|
||||
// First store maximum for all values processes by
|
||||
// the current thread in cache[threadIdx.x]
|
||||
scalar_t tmp = 0.0;
|
||||
while (i < num_elems) {
|
||||
float x = static_cast<float>(input[i]);
|
||||
tmp = fmaxf(tmp, fabsf(x));
|
||||
i += blockDim.x * gridDim.x;
|
||||
}
|
||||
cache[threadIdx.x] = tmp;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Now perform parallel reduction within the thread block
|
||||
int ib = blockDim.x / 2;
|
||||
while (ib != 0) {
|
||||
if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
|
||||
cache[threadIdx.x] = cache[threadIdx.x + ib];
|
||||
}
|
||||
__syncthreads();
|
||||
ib /= 2;
|
||||
}
|
||||
// Finally, since cache[0] contains the maximum for this thread block,
|
||||
// atomically write the max to the target location
|
||||
if (threadIdx.x == 0) {
|
||||
atomicMaxFloat(scale, cache[0] / quant_type_max_v<fp8_type>);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
|
||||
int64_t const num_elems, int const tid,
|
||||
int const step) {
|
||||
constexpr size_t VEC_SIZE = 16;
|
||||
using scalarxN_t = vec_n_t<scalar_t, VEC_SIZE>;
|
||||
// Vectorized input/output to better utilize memory bandwidth.
|
||||
auto const* vectorized_in = reinterpret_cast<scalarxN_t const*>(input);
|
||||
|
||||
// num_elems / VEC_SIZE (which is 16)
|
||||
int64_t const num_vec_elems = num_elems >> 4;
|
||||
float absmax_val = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int64_t i = tid; i < num_vec_elems; i += step) {
|
||||
scalarxN_t in_vec = vectorized_in[i];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; ++j) {
|
||||
absmax_val = fmaxf(absmax_val, fabsf(in_vec.val[j]));
|
||||
}
|
||||
}
|
||||
|
||||
// Handle the remaining elements if num_elems is not divisible by VEC_SIZE
|
||||
for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += step) {
|
||||
absmax_val = fmaxf(absmax_val, fabsf(input[i]));
|
||||
}
|
||||
|
||||
return absmax_val;
|
||||
}
|
||||
|
||||
template <typename scalar_t, bool is_scale_inverted, typename fp8_type>
|
||||
__device__ void scaled_fp8_conversion_vec(fp8_type* __restrict__ out,
|
||||
scalar_t const* __restrict__ input,
|
||||
float const scale,
|
||||
int64_t const num_elems,
|
||||
int const tid, int const step) {
|
||||
constexpr size_t VEC_SIZE = 16;
|
||||
using scalarxN_t = vec_n_t<scalar_t, VEC_SIZE>;
|
||||
using float8xN_t = q8_n_t<fp8_type, VEC_SIZE>;
|
||||
// Vectorized input/output to better utilize memory bandwidth.
|
||||
auto const* vectorized_in = reinterpret_cast<scalarxN_t const*>(input);
|
||||
auto* vectorized_out = reinterpret_cast<float8xN_t*>(out);
|
||||
|
||||
// num_elems / VEC_SIZE (which is 16)
|
||||
int64_t const num_vec_elems = num_elems >> 4;
|
||||
|
||||
#pragma unroll
|
||||
for (int64_t i = tid; i < num_vec_elems; i += step) {
|
||||
scalarxN_t in_vec = vectorized_in[i];
|
||||
float8xN_t out_vec;
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; ++j) {
|
||||
out_vec.val[j] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
|
||||
static_cast<float>(in_vec.val[j]), scale);
|
||||
}
|
||||
vectorized_out[i] = out_vec;
|
||||
}
|
||||
|
||||
// Handle the remaining elements if num_elems is not divisible by VEC_SIZE
|
||||
for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += step) {
|
||||
out[i] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
|
||||
static_cast<float>(input[i]), scale);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
@ -194,3 +194,36 @@ def test_scaled_fp8_quant(dtype) -> None:
|
||||
ref_y,
|
||||
per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale,
|
||||
dtype))
|
||||
|
||||
# non-contiguous input with padding
|
||||
m, n, padded_stride = 975, 512, 576
|
||||
padded_tensor = (torch.randn(size=(m, padded_stride), device="cuda") *
|
||||
13).to(dtype)
|
||||
x_nc = padded_tensor[:, :n] # shape (m, n) with stride (padded_stride, 1)
|
||||
|
||||
assert not x_nc.is_contiguous()
|
||||
assert x_nc.stride(0) == padded_stride
|
||||
|
||||
# dynamic quantization
|
||||
ref_y_nc, inv_scale_nc = ops.scaled_fp8_quant(x_nc, None)
|
||||
ref_y_nc = per_tensor_dequantize(ref_y_nc, inv_scale_nc, dtype)
|
||||
|
||||
# reference dynamic quantization
|
||||
y_nc = quantize_ref(x_nc, inv_scale_nc)
|
||||
torch.testing.assert_close(
|
||||
ref_y_nc, per_tensor_dequantize(y_nc, inv_scale_nc, dtype))
|
||||
|
||||
# static quantization
|
||||
y_nc, _ = ops.scaled_fp8_quant(x_nc, inv_scale_nc)
|
||||
torch.testing.assert_close(
|
||||
ref_y_nc, per_tensor_dequantize(y_nc, inv_scale_nc, dtype))
|
||||
|
||||
# padding after non-contiguous input quantization
|
||||
y_nc_pad, _ = ops.scaled_fp8_quant(x_nc,
|
||||
inv_scale_nc,
|
||||
num_token_padding=m + 10)
|
||||
assert y_nc_pad.shape[0] == m + 10
|
||||
torch.testing.assert_close(
|
||||
ref_y_nc,
|
||||
per_tensor_dequantize(torch.narrow(y_nc_pad, 0, 0, x_nc.shape[0]),
|
||||
inv_scale_nc, dtype))
|
||||
|
@ -1279,14 +1279,13 @@ def scaled_fp8_quant(
|
||||
device=input.device,
|
||||
dtype=torch.float32)
|
||||
torch.ops._C.dynamic_per_token_scaled_fp8_quant(
|
||||
output, input.contiguous(), scale, scale_ub)
|
||||
output, input, scale, scale_ub)
|
||||
else:
|
||||
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
||||
torch.ops._C.dynamic_scaled_fp8_quant(output, input.contiguous(),
|
||||
scale)
|
||||
scale = torch.empty(1, device=input.device, dtype=torch.float32)
|
||||
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
|
||||
else:
|
||||
assert scale.numel() == 1, f"{scale.shape}"
|
||||
torch.ops._C.static_scaled_fp8_quant(output, input.contiguous(), scale)
|
||||
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
|
||||
|
||||
return output, scale
|
||||
|
||||
|
Reference in New Issue
Block a user