Files
DeepSpeed/deepspeed/compression/helper.py
Nir Sonnenschein 1a8ad24f0d fix issues raised by Coverity scans (#7431)
This commit combines fixes for 37 potential code issues found in
Coverity scans.
The issues include, but are not limited to, potential access to
uninitialized variables and dead or redundant code.
We understand that reviewing such a commit can be difficult and will be
happy to help with any questions or changes required.

---------

Signed-off-by: Nir Sonnenschein <nsonnenschein@habana.ai>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
2025-08-02 12:16:10 -04:00

323 lines
14 KiB
Python

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from .basic_layer import Embedding_Compress, LinearLayer_Compress, Conv2dLayer_Compress, BNLayer_Compress, ColumnParallelLinear_Compress, RowParallelLinear_Compress
from .constants import *
from deepspeed.utils import logger
try:
from neural_compressor.compression import pruner as nc_pruner
except ImportError as e:
nc_pruner = None
def recursive_getattr(model, module_name):
    """
    Follow a dot-separated attribute path starting at `model`.
    Args:
        model (`torch.nn.Module`)
            The object the lookup starts from.
        module_name (`str`)
            Dotted path such as ``"a.b.c"``; each component is resolved with `getattr`.
    Returns:
        The object reached after resolving every component of the path.
    """
    current = model
    for attr_name in module_name.split('.'):
        current = getattr(current, attr_name)
    return current
def recursive_setattr(model, module_name, module):
    """
    Assign `module` at a dot-separated attribute path below `model`.
    Args:
        model (`torch.nn.Module`)
            The object the path is resolved against.
        module_name (`str`)
            Dotted path; every component but the last is looked up with `getattr`,
            the last one is the attribute that gets assigned.
        module (`torch.nn.Module`)
            The value stored under the final path component.
    """
    *parent_path, leaf_name = module_name.split('.')
    parent = model
    for attr_name in parent_path:
        parent = getattr(parent, attr_name)
    setattr(parent, leaf_name, module)
def _share_weight_and_bias(new_module, old_module, need_bias):
    """Make `new_module` reuse the weight (and, if requested, bias) tensors of `old_module`."""
    new_module.weight.data = old_module.weight.data
    if need_bias:
        new_module.bias.data = old_module.bias.data


def _build_compressed_module(old_module, mpu=None):
    """Return the `*_Compress` counterpart of `old_module`, or `None` when its type is not supported."""
    need_bias = getattr(old_module, 'bias', None) is not None

    # NOTE: branch order is significant and mirrors the historical if/elif chain.
    if isinstance(old_module, LinearLayer_Compress) or isinstance(old_module, torch.nn.Linear):
        if isinstance(old_module, LinearLayer_Compress):
            return old_module
        new_module = LinearLayer_Compress(old_module.in_features, old_module.out_features,
                                          bias=need_bias).to(device=old_module.weight.device,
                                                             dtype=old_module.weight.dtype)
        _share_weight_and_bias(new_module, old_module, need_bias)
        return new_module

    if isinstance(old_module, Conv2dLayer_Compress) or isinstance(old_module, torch.nn.Conv2d):
        if isinstance(old_module, Conv2dLayer_Compress):
            return old_module
        new_module = Conv2dLayer_Compress(old_module.in_channels, old_module.out_channels, old_module.kernel_size,
                                          old_module.stride, old_module.padding, old_module.dilation,
                                          old_module.groups, need_bias,
                                          old_module.padding_mode).to(device=old_module.weight.device,
                                                                      dtype=old_module.weight.dtype)
        _share_weight_and_bias(new_module, old_module, need_bias)
        return new_module

    if isinstance(old_module, BNLayer_Compress) or isinstance(old_module, torch.nn.BatchNorm2d):
        # Reuse an already wrapped layer, exactly like every other branch does.
        if isinstance(old_module, BNLayer_Compress):
            return old_module
        new_module = BNLayer_Compress(old_module.num_features, old_module.eps, old_module.momentum,
                                      old_module.affine, old_module.track_running_stats)
        # `weight`/`bias` are None when affine=False and the running statistics are None when
        # track_running_stats=False, so use whichever tensor exists as the device/dtype reference
        # and only copy the tensors that are actually present.
        reference = old_module.weight if old_module.weight is not None else old_module.running_mean
        if reference is not None:
            new_module = new_module.to(reference.device, reference.dtype)
        if old_module.weight is not None:
            new_module.weight.data = old_module.weight.data
        if need_bias:
            new_module.bias.data = old_module.bias.data
        if old_module.running_mean is not None:
            new_module.running_mean.data = old_module.running_mean.data
        if old_module.running_var is not None:
            new_module.running_var.data = old_module.running_var.data
        return new_module

    if isinstance(old_module, Embedding_Compress) or isinstance(old_module, torch.nn.Embedding):
        if isinstance(old_module, Embedding_Compress):
            return old_module
        new_module = Embedding_Compress(old_module.num_embeddings, old_module.embedding_dim,
                                        old_module.padding_idx, old_module.max_norm, old_module.norm_type,
                                        old_module.scale_grad_by_freq,
                                        old_module.sparse).to(device=old_module.weight.device,
                                                              dtype=old_module.weight.dtype)
        new_module.weight.data = old_module.weight.data
        return new_module

    # The model-parallel layer classes are only reachable through `mpu`.
    if mpu is not None and (isinstance(old_module, ColumnParallelLinear_Compress)
                            or isinstance(old_module, mpu.ColumnParallelLinear)):
        if isinstance(old_module, ColumnParallelLinear_Compress):
            return old_module
        new_module = ColumnParallelLinear_Compress(mpu,
                                                   old_module.input_size,
                                                   old_module.output_size,
                                                   gather_output=old_module.gather_output,
                                                   skip_bias_add=old_module.skip_bias_add,
                                                   bias=need_bias).to(device=old_module.weight.device,
                                                                      dtype=old_module.weight.dtype)
        _share_weight_and_bias(new_module, old_module, need_bias)
        return new_module

    if mpu is not None and (isinstance(old_module, RowParallelLinear_Compress)
                            or isinstance(old_module, mpu.RowParallelLinear)):
        if isinstance(old_module, RowParallelLinear_Compress):
            return old_module
        new_module = RowParallelLinear_Compress(mpu,
                                                old_module.input_size,
                                                old_module.output_size,
                                                input_is_parallel=old_module.input_is_parallel,
                                                skip_bias_add=old_module.skip_bias_add,
                                                bias=need_bias).to(device=old_module.weight.device,
                                                                   dtype=old_module.weight.dtype)
        _share_weight_and_bias(new_module, old_module, need_bias)
        return new_module

    return None


def _enable_compression_techniques(new_module, compression_technique):
    """Call the `enable_*` method of `new_module` for every enabled entry of `compression_technique`."""
    for k, v in compression_technique.items():
        if k == SPARSE_PRUNING:
            if v[SPARSE_PRUNING_ENABLED]:
                new_module.enable_sparse_pruning(v[SPARSE_PRUNING_DENSE_RATIO], v[SPARSE_PRUNING_METHOD])
        elif k == ROW_PRUNING:
            if v[ROW_PRUNING_ENABLED]:
                new_module.enable_row_pruning(v[ROW_PRUNING_DENSE_RATIO], v[ROW_PRUNING_METHOD])
        elif k == HEAD_PRUNING:
            if v[HEAD_PRUNING_ENABLED]:
                new_module.enable_head_pruning(v[HEAD_PRUNING_DENSE_RATIO], v[HEAD_PRUNING_METHOD],
                                               v[HEAD_PRUNING_NUM_HEADS])
        elif k == ACTIVATION_QUANTIZATION:
            if v[ACTIVATION_QUANTIZATION_ENABLED]:
                new_module.enable_activation_quantization(v[ACTIVATION_QUANTIZE_BITS], v[ACTIVATION_QUANTIZE_TYPE],
                                                          v[ACTIVATION_QUANTIZE_RANGE])
        elif k == WEIGHT_QUANTIZATION:
            if v[WEIGHT_QUANTIZE_ENABLED]:
                new_module.enable_weight_quantization(v[WEIGHT_QUANTIZE_START_BITS],
                                                      v[WEIGHT_QUANTIZE_TARGET_BITS],
                                                      v[WEIGHT_QUANTIZATION_PERIOD],
                                                      v[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED],
                                                      v[WEIGHT_QUANTIZE_TYPE], v[WEIGHT_QUANTIZE_GROUPS])
        elif k == CHANNEL_PRUNING:
            if v[CHANNEL_PRUNING_ENABLED]:
                new_module.enable_channel_pruning(v[CHANNEL_PRUNING_DENSE_RATIO], v[CHANNEL_PRUNING_METHOD])
        else:
            raise NotImplementedError('Compression technique {} is not implemented'.format(k))


def module_replacement(model, module_name, compression_technique=None, mpu=None):
    """
    Replace a module with its compression-aware wrapper and optionally enable compression on it.
    Args:
        model (`torch.nn.Module`)
            The model to replace the module in.
        module_name (`str`)
            Dot-separated name of the module to replace.
        compression_technique (`dict`, optional)
            Maps a technique key (e.g. `SPARSE_PRUNING`) to the settings dict of that technique.
            Every enabled technique is switched on for the wrapped module. `None` only wraps.
        mpu (optional)
            Object exposing `ColumnParallelLinear` and `RowParallelLinear` classes. When `None`,
            the model-parallel branches are skipped.
    Raises:
        NotImplementedError: if `compression_technique` contains an unknown technique key.
    """
    old_module = recursive_getattr(model, module_name)
    new_module = _build_compressed_module(old_module, mpu)

    if new_module is None:
        # Unsupported layer type: keep the existing module in place instead of overwriting it
        # with None, which would silently delete the layer from the model.
        logger.warning("module_replacement: leaving '%s' untouched, %s is not a supported layer type", module_name,
                       type(old_module).__name__)
        return

    if compression_technique is not None:
        _enable_compression_techniques(new_module, compression_technique)

    # Replace the old module with the new one
    recursive_setattr(model, module_name, new_module)
def is_module_compressible(module, mpu=None):
    """
    Tell whether `module` is one of the layer types the compression wrappers handle.
    Args:
        module
            The object to test.
        mpu (optional)
            When given, instances of `mpu.RowParallelLinear` and `mpu.ColumnParallelLinear`
            are accepted as well.
    Returns:
        `bool`: True for `torch.nn.Linear`, `Conv2d`, `Embedding`, `BatchNorm2d` and, with
        `mpu`, its two parallel linear classes. False otherwise.
    """
    native_types = (torch.nn.Linear, torch.nn.Conv2d, torch.nn.Embedding, torch.nn.BatchNorm2d)
    if isinstance(module, native_types):
        return True
    if mpu is None:
        return False
    return isinstance(module, mpu.RowParallelLinear) or isinstance(module, mpu.ColumnParallelLinear)
def compression_preparation(model, compression_technique_list, mpu):
    """
    Prepare the compression techniques of a model.
    Args:
        model (`torch.nn.Module`)
            The model to prepare the compression techniques of.
        compression_technique_list (`list`)
            Iterable of 3-tuples `(module_name_lists, _, compression_technique)`.
            `module_name_lists` is an iterable of iterables of module names, the second
            element is ignored here, and `compression_technique` is handed to
            `module_replacement` for every listed module name.
        mpu
            Forwarded to `is_module_compressible` and `module_replacement`. May be `None`.
    Returns:
        The same `model` object, modified in place.
    """
    # Here we first replace all module with our linear wrapper.
    # The traversal is snapshotted because module_replacement rewrites the module tree
    # that named_modules() would otherwise still be walking.
    for module_name, module in list(model.named_modules()):
        if is_module_compressible(module, mpu):
            module_replacement(model, module_name, mpu=mpu)
    for module_name_lists, _, compression_technique in compression_technique_list:
        for mnl in module_name_lists:
            for module_name in mnl:
                # Forward mpu here as well so model-parallel layers are recognised in this pass
                # exactly as they were in the wrapping pass above.
                module_replacement(model, module_name, compression_technique, mpu=mpu)
    return model
def fix_compression(model, module_name, compression_technique, mask=None, dim_reduction=False):
    """
    Invoke the matching `fix_*` helper of a module for the first applicable technique.
    Args:
        model (`torch.nn.Module`)
            The model that contains the module.
        module_name (`str`)
            Dot-separated name of the module whose compression is fixed.
        compression_technique (`dict`)
            Maps a technique key to its settings dict. Entries are visited in dict order.
        mask (optional)
            Forwarded to the row/head/channel pruning helpers. A non-`None` mask triggers
            those helpers even when the technique is not flagged as enabled.
        dim_reduction (`bool`)
            Forwarded to the row/head/channel pruning helpers.
    Returns:
        Whatever the invoked `fix_*` helper returns, or `None` if no entry applied.
    """
    # Here we can make things much simpler by just replacing the module
    target = recursive_getattr(model, module_name)
    for technique, settings in compression_technique.items():
        if technique == WEIGHT_QUANTIZATION:
            if settings[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED] and settings[WEIGHT_QUANTIZE_ENABLED]:
                return target.fix_weight_quantization()
        elif technique == SPARSE_PRUNING:
            if settings[SPARSE_PRUNING_ENABLED]:
                return target.fix_sparse_pruning_helper()
        elif technique == ROW_PRUNING:
            if settings[ROW_PRUNING_ENABLED] or mask is not None:
                return target.fix_row_col_pruning_helper(mask, dim_reduction=dim_reduction)
        elif technique == HEAD_PRUNING:
            if settings[HEAD_PRUNING_ENABLED] or mask is not None:
                return target.fix_head_pruning_helper(mask,
                                                      settings[HEAD_PRUNING_NUM_HEADS],
                                                      dim_reduction=dim_reduction)
        elif technique == CHANNEL_PRUNING:
            if settings[CHANNEL_PRUNING_ENABLED] or mask is not None:
                return target.fix_channel_pruning_helper(mask, dim_reduction=dim_reduction)
    return None
def convert_conv1d_to_linear(model, convert_type):
    '''
    Helper that swaps every `convert_type` submodule for an equivalent `torch.nn.Linear`
    (e.g., to convert GPT2 from HF).
    Args:
        model
            The model to convert. If it has a `module` attribute, that inner object is
            the one that gets traversed and modified.
        convert_type
            Class (or tuple of classes) whose instances are replaced.
    Returns:
        The `model` argument itself, after in-place replacement.
    '''
    c_model = model.module if hasattr(model, 'module') else model

    for name, module in c_model.named_modules():
        if not isinstance(module, convert_type):
            continue
        source = recursive_getattr(c_model, name)
        in_features = source.weight.data.size(0)
        out_features = source.weight.data.size(1)
        linear = torch.nn.Linear(in_features, out_features, bias=source.bias is not None)
        # The Linear weight is the transpose of the source weight.
        linear.weight.data = source.weight.data.t().contiguous()
        if linear.bias is not None:
            linear.bias.data = source.bias.data.view(-1)
        recursive_setattr(c_model, name, linear)

    return model
def generate_pruners(config, model):
    """Generate pruners.
    Args:
        config (`neural_compressor.WeightPruningConfig`)
            The object to the class WeightPruningConfig.
        model (`torch.nn.module`)
            The torch module object to be pruned.
    Returns:
        `list`: one pruner per entry produced by `process_config(config)`.
    Raises:
        AssertionError: if `neural_compressor` could not be imported or `model` is not a
            `torch.nn.Module`.
    """
    assert nc_pruner is not None, "please ensure the neural_compressor python package is installed by pip or conda if user wants to use snip_momentum sparse pruning"
    # `nc_pruner` is only a module-level alias; an import statement resolves its dotted name
    # through the import system, so the real package path has to be spelled out here.
    from neural_compressor.compression.pruner.utils import process_config, parse_to_prune
    from neural_compressor.compression.pruner.pruners import get_pruner
    assert isinstance(model, torch.nn.Module)
    pruners_info = process_config(config)
    pruners = []
    for info in pruners_info:
        modules = parse_to_prune(info, model)
        if modules == {}:
            logger.warning("one pruner hooks no layers, please have a check")

        pruners.append(get_pruner(info, modules))
        # Record which modules this pruner covers so the logged info is self-describing.
        info['modules'] = list(modules.keys())
        info['len_of_modules'] = len(info['modules'])
        logger.info(info)
    return pruners
def register_on_step_begin(model):
    """Mount on_step_begin to the model.
    Args:
        model (`torch.nn.module`)
            The torch module object to be pruned. Its `pruners` attribute is read each
            time the hook fires.
    Returns:
        The handle returned by `register_forward_pre_hook`.
    """

    def _notify_pruners(owner, inputs):
        # Runs before every forward call of `model`; each pruner is told step 0.
        for pruner in owner.pruners:
            pruner.on_step_begin(0)

    return model.register_forward_pre_hook(_notify_pruners)
def rewrite_optimizer_step(opt: torch.optim.Optimizer):
    """Mount on_before/after_optimizer_step to the optimizer.
    Args:
        opt (`torch.optim.Optimizer`)
            The torch optimizer object to be hooked. Its original `step` is kept as
            `opt.orig_step` and `opt.step` is replaced by a wrapper that notifies every
            pruner in `opt.pruners` (if that attribute exists) before and after the step.
    Returns:
        The same `opt` object.
    """
    import types

    # Already hooked: wrapping again would make `orig_step` point at the wrapper itself,
    # and every later step() call would recurse without end.
    if hasattr(opt, "orig_step"):
        return opt

    def new_step(self, closure=None):
        if hasattr(self, "pruners"):
            for pruner in self.pruners:
                pruner.on_before_optimizer_step()

        # Only forward the closure when one was supplied, as the original step may not accept None.
        if closure is not None:
            res = self.orig_step(closure)
        else:
            res = self.orig_step()
        if hasattr(self, "pruners"):
            for pruner in self.pruners:
                pruner.on_after_optimizer_step()
        return res

    opt.orig_step = opt.step
    opt.step = types.MethodType(new_step, opt)
    return opt