diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py
index ed84d0dd7..f2d717695 100644
--- a/vllm_ascend/quantization/quant_config.py
+++ b/vllm_ascend/quantization/quant_config.py
@@ -33,6 +33,7 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     UnquantizedEmbeddingMethod, VocabParallelEmbedding)
+from vllm.model_executor.parameter import PerTensorScaleParameter
 from vllm.model_executor.utils import set_weight_attrs
 
 from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
@@ -250,6 +251,7 @@ class AscendLinearMethod(LinearMethodBase):
         **extra_weight_attrs,
     ) -> None:
         output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
 
         weight_dict = self.quant_method.get_weight(input_size_per_partition,
                                                    output_size_per_partition,
@@ -262,7 +264,8 @@ class AscendLinearMethod(LinearMethodBase):
 
         pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
         for pertensor_name, pertensor_param in pertensor_dict.items():
-            param = torch.nn.Parameter(pertensor_param, requires_grad=False)
+            param = PerTensorScaleParameter(data=pertensor_param,
+                                            weight_loader=weight_loader)
             # disable warning
             param.ignore_warning = True
             layer.register_parameter(pertensor_name, param)
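
The functional change is that per-tensor scale parameters are now registered as `PerTensorScaleParameter` objects carrying the `weight_loader` pulled from `extra_weight_attrs`, instead of plain `torch.nn.Parameter`, so the same loader used for the layer's weights is attached to the scales at checkpoint-loading time. Below is a minimal sketch of the new registration path; `register_pertensor_params` is a hypothetical standalone helper written only for illustration (the real logic sits inside `AscendLinearMethod.create_weights`), and the shape of `pertensor_dict` is an assumption based on how it is used in the diff.

```python
# Illustrative sketch, not the actual AscendLinearMethod code.
import torch
from vllm.model_executor.parameter import PerTensorScaleParameter


def register_pertensor_params(layer: torch.nn.Module,
                              pertensor_dict: dict[str, torch.Tensor],
                              weight_loader) -> None:
    """Register per-tensor quantization scales on a layer.

    Assumes pertensor_dict maps parameter names to small scale tensors
    produced by the quant method, mirroring get_pertensor_param() in the diff.
    """
    for pertensor_name, pertensor_param in pertensor_dict.items():
        # Wrap the raw tensor so the weight_loader passed down via
        # extra_weight_attrs travels with the parameter.
        param = PerTensorScaleParameter(data=pertensor_param,
                                        weight_loader=weight_loader)
        # disable warning (mirrors the comment in the original code)
        param.ignore_warning = True
        layer.register_parameter(pertensor_name, param)
```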