Mirror of https://github.com/deepspeedai/DeepSpeed.git, synced 2025-10-20 15:33:51 +08:00
Fix fp8 gemm (#7265)
This PR addresses issue https://github.com/deepspeedai/DeepSpeed/issues/7236. I may have reverted some of the recent changes introduced in [PR #6932](https://github.com/deepspeedai/DeepSpeed/pull/6932); this was necessary to remove a misaligned-address issue in the CUDA kernel. I will get back to this and try to make the necessary changes for the other pass.

cc: @mrwyattii @jeffra

---------

Co-authored-by: Reza Yazdani <reza.yazdani@snowflake.com>
Co-authored-by: Reza Yazdani <rezay@microsoft.com>
Co-authored-by: Jeff Rasley <jeffra45@gmail.com>
Co-authored-by: Michael Wyatt <michael.wyatt@snowflake.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
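For context, the fp8 quantizer touched by this PR is typically exercised through a quantize/dequantize round trip. The sketch below is not part of this change; the import path, the `group_size` constructor argument, and the method signatures are assumptions inferred from the code in the diff, and running it requires a CUDA device with the fp_quantizer op built.

```python
# Hypothetical round-trip check (assumed API, inferred from the diff below; requires a
# CUDA build of the fp_quantizer kernels).
import torch
from deepspeed.ops.fp_quantizer import FP_Quantize  # assumed import path

quantizer = FP_Quantize(group_size=512)             # group_size argument is assumed
x = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")

x_q = quantizer.quantize(x, q_bits=8)               # fp8 payload with per-group scales packed in
x_dq = quantizer.dequantize(x_q, q_bits=8)          # back to bf16
print((x - x_dq).abs().max())                       # round-trip error should be small
```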
@@ -24,7 +24,6 @@
 
 at::Tensor quantize(torch::Tensor& out,
                     torch::Tensor& val,
-                    torch::Tensor& scale,
                     int group_size,
                     int stochastic_rounding,
                     int q_bits,
@@ -60,7 +59,6 @@ at::Tensor quantize(torch::Tensor& out,
 
 void dequantize(torch::Tensor& val,
                 torch::Tensor& val_q,
-                torch::Tensor& scale,
                 int group_size,
                 int q_mantisa_bits,
                 int q_exponent_bits)
@@ -39,25 +39,19 @@ def matmul_kernel_fp8_bf16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, str
     weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + (
         (pid_n * BLOCK_SIZE_N) // quantization_group_size)
 
-    weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0)
-    scale = tl.load(scale_ptr + weight_ptrs_offset)
-
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0)
+        scale = tl.load(scale_ptr + weight_ptrs_offset + ((k * BLOCK_SIZE_K * stride_bk) // quantization_group_size))
         # Dequantize weight (fp8 -> bf16)
-        w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 4)).to(tl.uint16)
+        w = (weight & 0x80).to(tl.uint16) << 8
+        w = w | ((weight & 0x7f).to(tl.uint16) << 4)
         w = (w + 0x3C00).to(tl.uint16)
-        w = (w.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16)
+        w = (w.to(tl.bfloat16, bitcast=True).to(tl.float32) * scale).to(tl.bfloat16)
 
         inp_data += BLOCK_SIZE_K * stride_ak
         weight_data += BLOCK_SIZE_K * stride_bk
-        weight_mask = offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K
-        weight = tl.load(weight_data, mask=weight_mask, other=0.0)
-        scale = tl.load(scale_ptr + (weight_ptrs_offset +
-                                     (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)),
-                        mask=weight_mask,
-                        other=0.0)
 
         accumulator += tl.dot(inp, w)
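The dequantization in the loop above reinterprets each fp8 byte as bf16: the sign bit moves to bit 15, the 4 exponent and 3 mantissa bits shift into the bf16 bit layout, and adding 0x3C00 shifts the exponent bias from 7 (e4m3) to 127 (bf16). The standalone sketch below is not part of the PR; it assumes e4m3 inputs and normal (non-subnormal, non-NaN) values, reproduces the same bit trick on the CPU, and checks it against PyTorch's own fp8 cast.

```python
# Minimal CPU sketch of the kernel's fp8 (e4m3) -> bf16 bit trick; assumes PyTorch >= 2.1
# for torch.float8_e4m3fn and covers normal values only.
import torch

def fp8_byte_to_bf16(byte: int) -> float:
    w = ((byte & 0x80) << 8) | ((byte & 0x7F) << 4)  # sign -> bit 15, exp+mantissa -> bits 10..4
    w = (w + 0x3C00) & 0xFFFF                        # re-bias exponent: 7 (e4m3) -> 127 (bf16)
    w_signed = w - 0x10000 if w >= 0x8000 else w     # reinterpret the 16 bits as a signed int16
    return torch.tensor([w_signed], dtype=torch.int16).view(torch.bfloat16).item()

# Cross-check against PyTorch's own fp8 -> bf16 conversion for a normal value.
x = torch.tensor([1.5]).to(torch.float8_e4m3fn)
byte = x.view(torch.uint8).item()
assert fp8_byte_to_bf16(byte) == x.to(torch.bfloat16).item()  # both give 1.5
```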
@@ -79,27 +79,15 @@ class FP_Quantize(Quantizer):
         else:
             assert (0), \
                 f"Missing {q_bits}-quantization, please add the template arguments for the kernel to support this precision!"
 
-        # Adding (group_size - 1) is for padding
-        self.num_groups = (input.numel() + self.q_config.group_size - 1) // self.q_config.group_size
-        # group_size should be the minimal number between the defined group size and number of elements in tensor.
-        group_size = int(min(self.q_config.group_size, input.numel()) * q_bits) // 8
-        # CUDA quantization kernel saves the scale as (fp32) inside the quantized tensor for each group
-        if self.cuda_impl:
-            group_size += 4
-        # CUDA quantization kernel allocates tensors as uint8, but handles them as fp8 inside the kernel.
-        self.input_q = torch.ones(self.num_groups, group_size, dtype=self.q_config.q_dtype, device=input.device)
-        # CUDA quantization kernel attaches scales to quantized result, in python implementation it can't be done
-        # because they are of different types.
-        self.scale = torch.ones(self.num_groups, 1, device=input.device)
-        out = fp_quant_module.quantize(self.input_q, input, self.scale, group_size, stochastic_mode, q_bits,
-                                       q_mantisa_bits)
+        self.num_groups = input.numel() // self.group_size
+        self.input_q = torch.ones(self.num_groups,
+                                  int(self.group_size * q_bits) // 8 + 4,
+                                  dtype=torch.uint8,
+                                  device=input.device)
+        out = fp_quant_module.quantize(self.input_q, input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits)
         if return_meta_tensor:
-            if self.cuda_impl:
-                data, self.scale = out.split(group_size, dim=-1)
-                data = data.contiguous().reshape(input.shape)
-            else:
-                data = out.contiguous().reshape(input.shape)
+            data, self.scale = out.split(self.group_size, dim=-1)
+            data = data.contiguous().reshape(input.shape)
             self.scale = self.scale.contiguous()
             del self.input_q
             del out
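The reverted allocation above packs each quantization group as `group_size * q_bits // 8` quantized bytes followed by 4 bytes holding the group's fp32 scale, which is why the `return_meta_tensor` path splits the result along the last dimension. A shape-only sketch, assuming `q_bits = 8` and using placeholder tensors rather than real quantized data:

```python
# Shape-only sketch of the packed per-group layout allocated by quantize() above
# (assumption: q_bits = 8, so one byte per element plus a 4-byte fp32 scale per group).
import torch

group_size, q_bits, num_groups = 512, 8, 16
row_bytes = int(group_size * q_bits) // 8 + 4         # quantized payload + 4-byte scale
packed = torch.zeros(num_groups, row_bytes, dtype=torch.uint8)

data, scale_bytes = packed.split(group_size, dim=-1)  # same split the return_meta_tensor path uses
print(data.shape, scale_bytes.shape)                  # torch.Size([16, 512]) torch.Size([16, 4])
scales = scale_bytes.contiguous().view(torch.float32)  # one fp32 scale per group, shape (16, 1)
```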
@@ -111,9 +99,9 @@ class FP_Quantize(Quantizer):
 
     def to(self, *args, **kwargs):
         # Intermediate tensors may need to be moved to different devices
-        if hasattr(self, 'input_q') and self.input_q is not None:
+        if hasattr(self, 'input_q'):
             self.input_q = self.input_q.to(*args, **kwargs)
-        if hasattr(self, 'scale') and self.scale is not None:
+        if hasattr(self, 'scale'):
             self.scale = self.scale.to(*args, **kwargs)
 
     def get_scales(self):
@@ -136,16 +124,11 @@ class FP_Quantize(Quantizer):
             assert (0), \
                 f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"
 
-        if scale is not None and self.cuda_impl:
+        if scale is not None:
             assert input_q.numel() == fp_out.numel(), \
                 f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
-            input_q = torch.cat([input_q.reshape(-1, self.q_config.group_size), scale], dim=-1).contiguous()
-        elif scale is not None and not self.cuda_impl:
-            group_size = int(min(self.q_config.group_size, input_q.numel()) * q_bits) // 8
-            input_q = input_q.reshape(-1, group_size)
-
-        fp_quant_module.dequantize(fp_out, input_q, self.scale, self.q_config.group_size, q_mantisa_bits,
-                                   q_bits - q_mantisa_bits - 1)
+            input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()
+        fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1)
         return fp_out
 
     def selective_dequantize(self,
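Conversely, when `dequantize` receives a separate meta tensor and scale, the code above re-attaches the scale bytes to the quantized rows before calling the kernel. A companion sketch to the split shown earlier, with the same assumptions and placeholder data:

```python
# Companion sketch: rebuilding the packed (payload + scale bytes) rows, mirroring the
# torch.cat in dequantize() above. Placeholder data, q_bits = 8.
import torch

group_size, num_groups = 512, 16
data = torch.zeros(num_groups * group_size, dtype=torch.uint8)  # flattened quantized payload
scale_bytes = torch.zeros(num_groups, 4, dtype=torch.uint8)     # 4 bytes of fp32 scale per group
packed = torch.cat([data.reshape(-1, group_size), scale_bytes], dim=-1).contiguous()
print(packed.shape)  # torch.Size([16, 516])
```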
@@ -174,11 +157,11 @@ class FP_Quantize(Quantizer):
             assert (0), \
                 f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"
 
-        if scale is not None and self.cuda_impl:
+        if scale is not None:
             assert input_q.numel() == fp_out.numel(), \
                 f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
-            input_q = torch.cat([input_q.reshape(-1, self.q_config.group_size), scale], dim=-1).contiguous()
+            input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()
 
-        fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.q_config.group_size, q_mantisa_bits,
+        fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.group_size, q_mantisa_bits,
                                              q_bits - q_mantisa_bits - 1)
         return fp_out
@@ -54,7 +54,7 @@ class FPQuantizerBuilder(CUDAOpBuilder):
             return False
 
         # triton 2.3.{0,1} and 3.0.0 are ok.
-        allowed_versions = ("2.3", "3.0")
+        allowed_versions = ("2.3", "3.0", "3.1", "3.2")
         if pkg_version:
             allowed = (pkg_version.parse(v) for v in allowed_versions)
             installed_triton = pkg_version.parse(triton.__version__)
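The builder hunk above widens the accepted Triton versions to include 3.1 and 3.2. Below is a standalone sketch of such a major.minor version gate using packaging; only the allowed_versions tuple comes from the diff, and the exact comparison the builder performs is an assumption.

```python
# Standalone sketch of a major.minor gate over the allowed_versions tuple from the diff.
# The comparison logic here is an assumption, not necessarily the builder's exact code.
from packaging import version as pkg_version

allowed_versions = ("2.3", "3.0", "3.1", "3.2")

def triton_version_ok(installed: str) -> bool:
    inst = pkg_version.parse(installed)
    return any(
        (inst.major, inst.minor) == (a.major, a.minor)
        for a in map(pkg_version.parse, allowed_versions)
    )

print(triton_version_ok("3.1.0"))  # True
print(triton_version_ok("2.2.0"))  # False
```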