diff --git a/csrc/fp_quantizer/fp_quantize.cpp b/csrc/fp_quantizer/fp_quantize.cpp
index 903d84270..1a887b50e 100644
--- a/csrc/fp_quantizer/fp_quantize.cpp
+++ b/csrc/fp_quantizer/fp_quantize.cpp
@@ -24,7 +24,6 @@ at::Tensor quantize(torch::Tensor& out,
                     torch::Tensor& val,
-                    torch::Tensor& scale,
                     int group_size,
                     int stochastic_rounding,
                     int q_bits,
@@ -60,7 +59,6 @@ at::Tensor quantize(torch::Tensor& out,
 void dequantize(torch::Tensor& val,
                 torch::Tensor& val_q,
-                torch::Tensor& scale,
                 int group_size,
                 int q_mantisa_bits,
                 int q_exponent_bits)
diff --git a/deepspeed/ops/fp_quantizer/fp8_gemm_triton.py b/deepspeed/ops/fp_quantizer/fp8_gemm_triton.py
index 746e217d4..086525cc6 100644
--- a/deepspeed/ops/fp_quantizer/fp8_gemm_triton.py
+++ b/deepspeed/ops/fp_quantizer/fp8_gemm_triton.py
@@ -39,25 +39,19 @@ def matmul_kernel_fp8_bf16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, str
     weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + (
         (pid_n * BLOCK_SIZE_N) // quantization_group_size)

-    weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0)
-    scale = tl.load(scale_ptr + weight_ptrs_offset)
-
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0)
+        scale = tl.load(scale_ptr + weight_ptrs_offset + ((k * BLOCK_SIZE_K * stride_bk) // quantization_group_size))
         # Dequantize weight (fp8 -> bf16)
-        w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 4)).to(tl.uint16)
+        w = (weight & 0x80).to(tl.uint16) << 8
+        w = w | ((weight & 0x7f).to(tl.uint16) << 4)
         w = (w + 0x3C00).to(tl.uint16)
-        w = (w.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16)
+        w = (w.to(tl.bfloat16, bitcast=True).to(tl.float32) * scale).to(tl.bfloat16)

         inp_data += BLOCK_SIZE_K * stride_ak
         weight_data += BLOCK_SIZE_K * stride_bk
-        weight_mask = offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K
-        weight = tl.load(weight_data, mask=weight_mask, other=0.0)
-        scale = tl.load(scale_ptr + (weight_ptrs_offset +
-                                     (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)),
-                        mask=weight_mask,
-                        other=0.0)

         accumulator += tl.dot(inp, w)
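
Illustrative sketch (not part of the patch): the rewritten loop body above rebuilds a bf16 bit pattern from each quantized byte (bit 7 is the sign, bits 0-6 are the exponent/mantissa payload) and applies the per-group scale in fp32. The host-side snippet below mimics that bit manipulation; the function name is made up, it assumes one fp32 scale broadcastable to the quantized tensor, and it widens to a float32 bit pattern instead of bitcasting through uint16 as the Triton code does.

import torch

def dequant_like_kernel(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # q: uint8-encoded 8-bit floats (bit 7 = sign, bits 0-6 = exponent/mantissa payload)
    w = q.to(torch.int32)
    sign = (w & 0x80) != 0
    mag = ((w & 0x7F) << 4) + 0x3C00       # payload shifted into the bf16 exponent/mantissa field, exponent re-biased
    f32 = (mag << 16).view(torch.float32)  # a bf16 bit pattern sits in the top 16 bits of a float32
    f32 = torch.where(sign, -f32, f32)     # re-apply the sign bit
    return (f32 * scale.float()).to(torch.bfloat16)
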
diff --git a/deepspeed/ops/fp_quantizer/quantize.py b/deepspeed/ops/fp_quantizer/quantize.py
index 69c21eaf6..47b3b08c7 100644
--- a/deepspeed/ops/fp_quantizer/quantize.py
+++ b/deepspeed/ops/fp_quantizer/quantize.py
@@ -79,27 +79,15 @@ class FP_Quantize(Quantizer):
         else:
             assert (0), \
                 f"Missing {q_bits}-quantization, please add the template arguments for the kernel to support this precision!"
-
-        # Adding (group_size - 1) is for padding
-        self.num_groups = (input.numel() + self.q_config.group_size - 1) // self.q_config.group_size
-        # group_size should be the minimal number between the defined group size and number of elements in tensor.
-        group_size = int(min(self.q_config.group_size, input.numel()) * q_bits) // 8
-        # CUDA quantization kernel saves the scale as (fp32) inside the quantized tensor for each group
-        if self.cuda_impl:
-            group_size += 4
-        # CUDA quantization kernel allocates tensors as uint8, but handles them as fp8 inside the kernel.
-        self.input_q = torch.ones(self.num_groups, group_size, dtype=self.q_config.q_dtype, device=input.device)
-        # CUDA quantization kernel attaches scales to quantized result, in python implementation it can't be done
-        # because they are of different types.
-        self.scale = torch.ones(self.num_groups, 1, device=input.device)
-        out = fp_quant_module.quantize(self.input_q, input, self.scale, group_size, stochastic_mode, q_bits,
-                                       q_mantisa_bits)
+        self.num_groups = input.numel() // self.group_size
+        self.input_q = torch.ones(self.num_groups,
+                                  int(self.group_size * q_bits) // 8 + 4,
+                                  dtype=torch.uint8,
+                                  device=input.device)
+        out = fp_quant_module.quantize(self.input_q, input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits)
         if return_meta_tensor:
-            if self.cuda_impl:
-                data, self.scale = out.split(group_size, dim=-1)
-                data = data.contiguous().reshape(input.shape)
-            else:
-                data = out.contiguous().reshape(input.shape)
+            data, self.scale = out.split(self.group_size, dim=-1)
+            data = data.contiguous().reshape(input.shape)
             self.scale = self.scale.contiguous()
             del self.input_q
             del out
@@ -111,9 +99,9 @@ class FP_Quantize(Quantizer):

     def to(self, *args, **kwargs):
         # Intermediate tensors may need to be moved to different devices
-        if hasattr(self, 'input_q') and self.input_q is not None:
+        if hasattr(self, 'input_q'):
             self.input_q = self.input_q.to(*args, **kwargs)
-        if hasattr(self, 'scale') and self.scale is not None:
+        if hasattr(self, 'scale'):
             self.scale = self.scale.to(*args, **kwargs)

     def get_scales(self):
@@ -136,16 +124,11 @@
             assert (0), \
                 f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"

-        if scale is not None and self.cuda_impl:
+        if scale is not None:
             assert input_q.numel() == fp_out.numel(), \
             f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
-            input_q = torch.cat([input_q.reshape(-1, self.q_config.group_size), scale], dim=-1).contiguous()
-        elif scale is not None and not self.cuda_impl:
-            group_size = int(min(self.q_config.group_size, input_q.numel()) * q_bits) // 8
-            input_q = input_q.reshape(-1, group_size)
-
-        fp_quant_module.dequantize(fp_out, input_q, self.scale, self.q_config.group_size, q_mantisa_bits,
-                                   q_bits - q_mantisa_bits - 1)
+            input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()
+        fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1)
         return fp_out

     def selective_dequantize(self,
@@ -174,11 +157,11 @@
             assert (0), \
                 f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"

-        if scale is not None and self.cuda_impl:
+        if scale is not None:
             assert input_q.numel() == fp_out.numel(), \
             f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
-            input_q = torch.cat([input_q.reshape(-1, self.q_config.group_size), scale], dim=-1).contiguous()
+            input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()

-        fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.q_config.group_size, q_mantisa_bits,
+        fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.group_size, q_mantisa_bits,
                                              q_bits - q_mantisa_bits - 1)
         return fp_out
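
For orientation (illustrative only, not part of the patch): with these hunks the CUDA kernel again appends each group's scale to the tail of the quantized buffer, so self.input_q carries group_size * q_bits // 8 + 4 bytes per group and out.split(self.group_size, dim=-1) separates payload from scale bytes. A minimal sketch of that layout, assuming 8-bit quantization and one fp32 scale per group; the tensor names here are made up.

import torch

group_size, num_groups = 128, 4
# Per group: group_size quantized payload bytes followed by 4 bytes holding the group's fp32 scale.
packed = torch.zeros(num_groups, group_size + 4, dtype=torch.uint8)

data, scale_bytes = packed.split(group_size, dim=-1)    # the same split the patched quantize() performs
payload = data.contiguous().reshape(-1)                 # quantized values (reshaped back to the input's shape in the patch)
scales = scale_bytes.contiguous().view(torch.float32)   # one way to reinterpret the trailing bytes as (num_groups, 1) fp32
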
diff --git a/op_builder/fp_quantizer.py b/op_builder/fp_quantizer.py
index df4d967ea..2b962ac2c 100644
--- a/op_builder/fp_quantizer.py
+++ b/op_builder/fp_quantizer.py
@@ -54,7 +54,7 @@ class FPQuantizerBuilder(CUDAOpBuilder):
             return False

         # triton 2.3.{0,1} and 3.0.0 are ok.
-        allowed_versions = ("2.3", "3.0")
+        allowed_versions = ("2.3", "3.0", "3.1", "3.2")
         if pkg_version:
             allowed = (pkg_version.parse(v) for v in allowed_versions)
             installed_triton = pkg_version.parse(triton.__version__)
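
The hunk above only widens the allow-list; the comparison that consumes allowed and installed_triton sits below the hunk and is not shown here. A plausible major.minor check in the same spirit (an illustrative sketch, not the builder's actual code):

from packaging import version as pkg_version

import triton

allowed_versions = ("2.3", "3.0", "3.1", "3.2")
installed_triton = pkg_version.parse(triton.__version__)
# Accept any patch release of an allowed major.minor series.
compatible = any(installed_triton.major == a.major and installed_triton.minor == a.minor
                 for a in (pkg_version.parse(v) for v in allowed_versions))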