Fix fp8 gemm (#7265)

This PR addresses issue
https://github.com/deepspeedai/DeepSpeed/issues/7236.
I reverted some of the recent changes introduced in this
[PR](https://github.com/deepspeedai/DeepSpeed/pull/6932), which was
necessary to remove a misaligned-address issue in the CUDA kernel. I
will get back to this and make the changes needed for the
other pass.

cc: @mrwyattii @jeffra

---------

Co-authored-by: Reza Yazdani <reza.yazdani@snowflake.com>
Co-authored-by: Reza Yazdani <rezay@microsoft.com>
Co-authored-by: Jeff Rasley <jeffra45@gmail.com>
Co-authored-by: Michael Wyatt <michael.wyatt@snowflake.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
commit 069ec31c59 (parent e1ba9e614f)
Author: Reza Yazdani
Date: 2025-05-08 15:21:52 -07:00
Committed by: GitHub
4 changed files with 22 additions and 47 deletions

@@ -24,7 +24,6 @@
 at::Tensor quantize(torch::Tensor& out,
                     torch::Tensor& val,
-                    torch::Tensor& scale,
                     int group_size,
                     int stochastic_rounding,
                     int q_bits,
@@ -60,7 +59,6 @@ at::Tensor quantize(torch::Tensor& out,
 void dequantize(torch::Tensor& val,
                 torch::Tensor& val_q,
-                torch::Tensor& scale,
                 int group_size,
                 int q_mantisa_bits,
                 int q_exponent_bits)
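
With the standalone `scale` argument dropped from both bindings, the extension is again driven with only the pre-allocated quantized buffer, the source tensor, and the group size; the per-group scale travels inside the quantized buffer itself. Below is a rough sketch of the resulting call pattern, assuming a CUDA build of the `fp_quantizer` op; the shapes, the group size, and the `FPQuantizerBuilder().load()` loading step mirror how `FP_Quantize` uses the module but are illustrative rather than authoritative:

```python
import torch
from deepspeed.ops.op_builder import FPQuantizerBuilder

# Illustrative sketch: load the extension the same way FP_Quantize does.
fp_quant_module = FPQuantizerBuilder().load()

group_size, q_bits, q_mantisa_bits = 512, 8, 3
x = torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda")

num_groups = x.numel() // group_size
# group_size fp8 bytes per group plus 4 extra bytes per group for the fp32 scale.
input_q = torch.ones(num_groups, int(group_size * q_bits) // 8 + 4, dtype=torch.uint8, device=x.device)
out = fp_quant_module.quantize(input_q, x, group_size, 0, q_bits, q_mantisa_bits)  # 0 = no stochastic rounding

fp_out = torch.empty_like(x)
fp_quant_module.dequantize(fp_out, out, group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1)
```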

@@ -39,25 +39,19 @@ def matmul_kernel_fp8_bf16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, str
     weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + (
         (pid_n * BLOCK_SIZE_N) // quantization_group_size)
-    weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0)
-    scale = tl.load(scale_ptr + weight_ptrs_offset)
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0)
+        scale = tl.load(scale_ptr + weight_ptrs_offset + ((k * BLOCK_SIZE_K * stride_bk) // quantization_group_size))
         # Dequantize weight (fp8 -> bf16)
-        w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 4)).to(tl.uint16)
+        w = (weight & 0x80).to(tl.uint16) << 8
+        w = w | ((weight & 0x7f).to(tl.uint16) << 4)
         w = (w + 0x3C00).to(tl.uint16)
-        w = (w.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16)
+        w = (w.to(tl.bfloat16, bitcast=True).to(tl.float32) * scale).to(tl.bfloat16)
         inp_data += BLOCK_SIZE_K * stride_ak
         weight_data += BLOCK_SIZE_K * stride_bk
-        weight_mask = offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K
-        weight = tl.load(weight_data, mask=weight_mask, other=0.0)
-        scale = tl.load(scale_ptr + (weight_ptrs_offset +
-                                     (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)),
-                        mask=weight_mask,
-                        other=0.0)
         accumulator += tl.dot(inp, w)
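
For reference, the dequantization bit trick above can be reproduced outside Triton. The sketch below assumes the packed byte is 1 sign, 4 exponent and 3 mantissa bits (the default q_bits=8 / q_mantisa_bits=3 configuration); adding 0x3C00 adds 120 = 127 - 7 to the bf16 exponent field, i.e. it re-biases the fp8 exponent into bf16's range. NaN/denormal corner cases get no special treatment, just as in the bit manipulation itself:

```python
import torch

def dequant_fp8_to_bf16(weight_u8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Plain-PyTorch sketch of the kernel's fp8 -> bf16 bit trick (1s/4e/3m layout assumed)."""
    w = weight_u8.to(torch.int32)
    # Exponent/mantissa bits land in bits 10..4 of the bf16 pattern; 0x3C00 adds
    # 127 - 7 = 120 to the bf16 exponent field to convert the fp8 bias.
    magnitude_bits = ((w & 0x7F) << 4) + 0x3C00
    # bf16 is the upper half of fp32: shift into the top 16 bits and bit-cast.
    magnitude = (magnitude_bits << 16).view(torch.float32)
    sign = 1.0 - 2.0 * ((w >> 7) & 1).to(torch.float32)
    return (sign * magnitude * scale).to(torch.bfloat16)

# One group of four fp8 bytes (1.0, -1.0, 2.0, 4.0) with a single fp32 scale of 0.5.
q = torch.tensor([0x38, 0xB8, 0x40, 0x48], dtype=torch.uint8)
print(dequant_fp8_to_bf16(q, torch.tensor(0.5)))  # tensor([ 0.5, -0.5, 1.0, 2.0], dtype=torch.bfloat16)
```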

@@ -79,27 +79,15 @@ class FP_Quantize(Quantizer):
         else:
             assert (0), \
                 f"Missing {q_bits}-quantization, please add the template arguments for the kernel to support this precision!"
-        # Adding (group_size - 1) is for padding
-        self.num_groups = (input.numel() + self.q_config.group_size - 1) // self.q_config.group_size
-        # group_size should be the minimal number between the defined group size and number of elements in tensor.
-        group_size = int(min(self.q_config.group_size, input.numel()) * q_bits) // 8
-        # CUDA quantization kernel saves the scale as (fp32) inside the quantized tensor for each group
-        if self.cuda_impl:
-            group_size += 4
-        # CUDA quantization kernel allocates tensors as uint8, but handles them as fp8 inside the kernel.
-        self.input_q = torch.ones(self.num_groups, group_size, dtype=self.q_config.q_dtype, device=input.device)
-        # CUDA quantization kernel attaches scales to quantized result, in python implementation it can't be done
-        # because they are of different types.
-        self.scale = torch.ones(self.num_groups, 1, device=input.device)
-        out = fp_quant_module.quantize(self.input_q, input, self.scale, group_size, stochastic_mode, q_bits,
-                                       q_mantisa_bits)
+        self.num_groups = input.numel() // self.group_size
+        self.input_q = torch.ones(self.num_groups,
+                                  int(self.group_size * q_bits) // 8 + 4,
+                                  dtype=torch.uint8,
+                                  device=input.device)
+        out = fp_quant_module.quantize(self.input_q, input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits)
         if return_meta_tensor:
-            if self.cuda_impl:
-                data, self.scale = out.split(group_size, dim=-1)
-                data = data.contiguous().reshape(input.shape)
-            else:
-                data = out.contiguous().reshape(input.shape)
+            data, self.scale = out.split(self.group_size, dim=-1)
+            data = data.contiguous().reshape(input.shape)
             self.scale = self.scale.contiguous()
             del self.input_q
             del out
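
The reverted layout stores each group's fp32 scale directly behind that group's quantized bytes, which is why the buffer gets `int(self.group_size * q_bits) // 8 + 4` columns and why `out.split(self.group_size, dim=-1)` separates data from scale. A small CPU-only illustration of that layout (shapes and group size are made up for the example):

```python
import torch

group_size, q_bits, num_groups = 8, 8, 4

# Each row: group_size quantized bytes followed by 4 bytes holding that group's fp32 scale.
buf = torch.zeros(num_groups, int(group_size * q_bits) // 8 + 4, dtype=torch.uint8)

# What quantize() returns when return_meta_tensor=True:
data, scale = buf.split(group_size, dim=-1)        # shapes (4, 8) and (4, 4), both uint8
print(data.shape, scale.shape)

# What dequantize() does before calling the kernel when a scale is passed back in:
rejoined = torch.cat([data.reshape(-1, group_size), scale], dim=-1).contiguous()
print(torch.equal(rejoined, buf))                  # True: the kernel sees data and scale interleaved again
```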
@@ -111,9 +99,9 @@ class FP_Quantize(Quantizer):
     def to(self, *args, **kwargs):
         # Intermediate tensors may need to be moved to different devices
-        if hasattr(self, 'input_q') and self.input_q is not None:
+        if hasattr(self, 'input_q'):
             self.input_q = self.input_q.to(*args, **kwargs)
-        if hasattr(self, 'scale') and self.scale is not None:
+        if hasattr(self, 'scale'):
             self.scale = self.scale.to(*args, **kwargs)

     def get_scales(self):
@@ -136,16 +124,11 @@ class FP_Quantize(Quantizer):
             assert (0), \
                 f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"

-        if scale is not None and self.cuda_impl:
+        if scale is not None:
             assert input_q.numel() == fp_out.numel(), \
                 f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
-            input_q = torch.cat([input_q.reshape(-1, self.q_config.group_size), scale], dim=-1).contiguous()
-        elif scale is not None and not self.cuda_impl:
-            group_size = int(min(self.q_config.group_size, input_q.numel()) * q_bits) // 8
-            input_q = input_q.reshape(-1, group_size)
-        fp_quant_module.dequantize(fp_out, input_q, self.scale, self.q_config.group_size, q_mantisa_bits,
-                                   q_bits - q_mantisa_bits - 1)
+            input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()
+        fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1)
         return fp_out

     def selective_dequantize(self,
@@ -174,11 +157,11 @@ class FP_Quantize(Quantizer):
             assert (0), \
                 f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"

-        if scale is not None and self.cuda_impl:
+        if scale is not None:
             assert input_q.numel() == fp_out.numel(), \
                 f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
-            input_q = torch.cat([input_q.reshape(-1, self.q_config.group_size), scale], dim=-1).contiguous()
-        fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.q_config.group_size, q_mantisa_bits,
+            input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()
+        fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.group_size, q_mantisa_bits,
                                              q_bits - q_mantisa_bits - 1)
         return fp_out

@@ -54,7 +54,7 @@ class FPQuantizerBuilder(CUDAOpBuilder):
             return False

         # triton 2.3.{0,1} and 3.0.0 are ok.
-        allowed_versions = ("2.3", "3.0")
+        allowed_versions = ("2.3", "3.0", "3.1", "3.2")
         if pkg_version:
             allowed = (pkg_version.parse(v) for v in allowed_versions)
             installed_triton = pkg_version.parse(triton.__version__)
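
The allow-list is matched against the installed triton by major.minor, so any 2.3.x, 3.0.x, 3.1.x or 3.2.x install now passes the gate. A sketch of that check, assuming the comparison works the way the two-component allow-list suggests (the function name below is illustrative; the real logic lives inside the builder's compatibility check):

```python
from packaging import version as pkg_version

ALLOWED_TRITON = ("2.3", "3.0", "3.1", "3.2")

def triton_version_ok(installed: str) -> bool:
    # Accept any installed triton whose major.minor matches an allowed entry.
    installed_triton = pkg_version.parse(installed)
    allowed = (pkg_version.parse(v) for v in ALLOWED_TRITON)
    return any(a.major == installed_triton.major and a.minor == installed_triton.minor for a in allowed)

print(triton_version_ok("3.2.0"))  # True  -> fp_quantizer can build against the newer triton
print(triton_version_ok("2.2.1"))  # False -> builder reports the op as incompatible
```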