Enable ZP Support for Machete (#20268)

Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
This commit is contained in:
czhu-cohere
2025-07-01 00:12:20 -07:00
committed by GitHub
parent 22e9d42040
commit 9909726d2a
3 changed files with 19 additions and 5 deletions

View File

@ -234,8 +234,10 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
fn = lambda: ops.gptq_marlin_gemm(
a=bt.a,
c=None,
b_q_weight=w_q,
b_scales=w_s,
global_scale=None,
b_zeros=w_zp,
g_idx=g_idx,
perm=sort_indices,

View File

@ -139,7 +139,7 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
def group_size_valid(shape: tuple[int, int, int],
group_size: Optional[int]) -> bool:
return group_size is None or group_size == -1 or group_size % shape[2] == 0
return group_size is None or group_size == -1 or shape[2] % group_size == 0
def machete_quantize_and_pack(atype: torch.dtype,

View File

@ -33,8 +33,6 @@ class MacheteLinearKernel(MPLinearKernel):
return False, "Act reordering currently not supported by Machete, "\
"when the input features are partitioned across "\
"devices"
if c.zero_points:
return False, "Zero points currently not supported by Machete"
if c.weight_type not in query_machete_supported_quant_types(
c.zero_points):
@ -53,6 +51,7 @@ class MacheteLinearKernel(MPLinearKernel):
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
# `weight_zp` is: {input_dim = 0, output_dim = 1, packed_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module):
c = self.config
@ -90,16 +89,29 @@ class MacheteLinearKernel(MPLinearKernel):
x.data = x.data.contiguous()
return x
def transform_w_zp(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=1)
x_unpacked = unpack_quantized_values_into_int32(x.data,
c.weight_type,
packed_dim=1)
w_s = getattr(layer, self.w_s_name).data
# pre-apply scales to zero-points
x.data = (-1.0 * w_s * (x_unpacked.to(w_s.dtype))).contiguous()
return x
# Repack weights and scales for Machete
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)
if c.zero_points:
self._transform_param(layer, self.w_zp_name, transform_w_zp)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
c = self.config
w_q, w_s, _, _ = self._get_weight_params(layer)
w_q, w_s, w_zp, _ = self._get_weight_params(layer)
x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
@ -110,7 +122,7 @@ class MacheteLinearKernel(MPLinearKernel):
output = ops.machete_mm(a=x_2d,
b_q=w_q,
b_type=c.weight_type,
b_group_zeros=None,
b_group_zeros=w_zp,
b_group_scales=w_s,
b_group_size=c.group_size)