# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .compress import get_module_name
from .constants import *
from .helper import recursive_getattr
from deepspeed.utils import logger

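# The scheduler consumes the per-technique sections of a DeepSpeed
# compression config (weight/activation quantization, sparse/head/row/channel
# pruning). Judging from how make_init() below indexes the config, each
# technique section carries a SHARED_PARAMETERS dict (enabled flag, schedule
# offset, ...) and a DIFFERENT_GROUPS mapping of group name -> group
# parameters ('params' plus module-scope keywords); the concrete key strings
# are defined in .constants.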
class compression_scheduler():
    '''
    Used to schedule different compression methods
    '''
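
    # Minimal usage sketch (hypothetical driver loop; in practice the
    # DeepSpeed engine constructs and steps this scheduler itself, so
    # `dataloader` and `train_step` below are placeholders):
    #
    #   scheduler = compression_scheduler(model, compression_config)
    #   scheduler.step(step_zero_check=True)  # fire techniques scheduled at step 0
    #   for batch in dataloader:
    #       train_step(model, batch)
    #       scheduler.step()  # advance the count, enable newly due techniques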
    def __init__(self, model, compression_config):
        self.model = model
        self.compression_config = compression_config
        self.make_init()
        self.training_steps = 0
        self.weight_quantization_enabled = False

        self.verbose = {
            WEIGHT_QUANTIZATION: False,
            ACTIVATION_QUANTIZATION: False,
            SPARSE_PRUNING: False,
            HEAD_PRUNING: False,
            ROW_PRUNING: False,
            CHANNEL_PRUNING: False
        }
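
    # Flatten the config into self.different_compression_methods, one record
    # per technique:
    #   {TECHNIQUE_ENABLED: bool,
    #    SHARED_PARAMETERS: dict,
    #    DIFFERENT_GROUPS: [[group_name, [module names], params], ...]}
    # Module names are resolved once here so the check_* methods can simply
    # walk the lists at training time.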
    def make_init(self):
        self.different_compression_methods = {}
        for method, method_content in self.compression_config.items():
            if LAYER_REDUCTION in method:
                continue
            self.different_compression_methods[method] = {
                TECHNIQUE_ENABLED: False,
                SHARED_PARAMETERS: None,
                DIFFERENT_GROUPS: []
            }
            exist_module_name = set()
            shared_parameters = method_content[SHARED_PARAMETERS]
            self.different_compression_methods[method][TECHNIQUE_ENABLED] = shared_parameters[TECHNIQUE_ENABLED]
            self.different_compression_methods[method][SHARED_PARAMETERS] = shared_parameters

            for group_name, method_parameters in method_content[DIFFERENT_GROUPS].items():
                module_name_list = []
                for key_word in method_parameters[DIFFERENT_GROUPS_MODULE_SCOPE]:
                    module_name, exist_module_name = get_module_name(group_name,
                                                                     self.model,
                                                                     key_word,
                                                                     exist_module_name,
                                                                     verbose=False)
                    module_name_list.extend(module_name)
                if module_name_list:
                    self.different_compression_methods[method][DIFFERENT_GROUPS].append(
                        [group_name, module_name_list,
                         method_parameters.copy().pop('params')])

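    # Each check_* method below follows the same pattern: once
    # self.training_steps reaches the technique's schedule offset, flip the
    # matching *_enabled attribute on every module in the technique's groups,
    # and log the transition the first time it happens.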
    def check_weight_quantization(self):
        # check weight quantization
        wq = self.different_compression_methods[WEIGHT_QUANTIZATION]
        if not wq[TECHNIQUE_ENABLED]:
            return
        else:
            shared_parameters = wq[SHARED_PARAMETERS]
            if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
                for group_name, module_name_list, method_parameters in wq[DIFFERENT_GROUPS]:
                    for module_name in module_name_list:
                        module = recursive_getattr(self.model, module_name)
                        module.weight_quantization_enabled = True

                if not self.verbose[WEIGHT_QUANTIZATION]:
                    logger.info(f'Weight quantization is enabled at step {self.training_steps}')
                    self.weight_quantization_enabled = True
                    self.verbose[WEIGHT_QUANTIZATION] = True

    def check_activation_quantization(self):
        # check activation quantization
        aq = self.different_compression_methods[ACTIVATION_QUANTIZATION]
        if not aq[TECHNIQUE_ENABLED]:
            return
        else:
            shared_parameters = aq[SHARED_PARAMETERS]
            if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
                for group_name, module_name_list, method_parameters in aq[DIFFERENT_GROUPS]:
                    for module_name in module_name_list:
                        module = recursive_getattr(self.model, module_name)
                        module.activation_quantization_enabled = True
                if not self.verbose[ACTIVATION_QUANTIZATION]:
                    logger.info(f'Activation quantization is enabled at step {self.training_steps}')
                    self.verbose[ACTIVATION_QUANTIZATION] = True

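    # Unlike the other techniques, sparse pruning is gated by a window: it is
    # active from TECHNIQUE_SCHEDULE_OFFSET up to and including
    # TECHNIQUE_SCHEDULE_OFFSET_END, rather than from the offset onward.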
    def check_sparse_pruning(self):
        # check sparse pruning
        sp = self.different_compression_methods[SPARSE_PRUNING]
        if not sp[TECHNIQUE_ENABLED]:
            return
        else:
            shared_parameters = sp[SHARED_PARAMETERS]
            if (shared_parameters[TECHNIQUE_SCHEDULE_OFFSET] <= self.training_steps
                    <= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET_END]):
                for group_name, module_name_list, method_parameters in sp[DIFFERENT_GROUPS]:
                    for module_name in module_name_list:
                        module = recursive_getattr(self.model, module_name)
                        module.sparse_pruning_enabled = True
                if not self.verbose[SPARSE_PRUNING]:
                    logger.info(f'Sparse pruning is enabled at step {self.training_steps}')
                    self.verbose[SPARSE_PRUNING] = True

    def check_head_pruning(self):
        # check head pruning
        hp = self.different_compression_methods[HEAD_PRUNING]
        if not hp[TECHNIQUE_ENABLED]:
            return
        else:
            shared_parameters = hp[SHARED_PARAMETERS]
            if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
                for group_name, module_name_list, method_parameters in hp[DIFFERENT_GROUPS]:
                    for module_name in module_name_list:
                        module = recursive_getattr(self.model, module_name)
                        module.head_pruning_enabled = True
                if not self.verbose[HEAD_PRUNING]:
                    logger.info(f'Head pruning is enabled at step {self.training_steps}')
                    self.verbose[HEAD_PRUNING] = True

    def check_row_pruning(self):
        # check row pruning
        rp = self.different_compression_methods[ROW_PRUNING]
        if not rp[TECHNIQUE_ENABLED]:
            return
        else:
            shared_parameters = rp[SHARED_PARAMETERS]
            if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
                for group_name, module_name_list, method_parameters in rp[DIFFERENT_GROUPS]:
                    for module_name in module_name_list:
                        module = recursive_getattr(self.model, module_name)
                        module.row_pruning_enabled = True
                if not self.verbose[ROW_PRUNING]:
                    logger.info(f'Row pruning is enabled at step {self.training_steps}')
                    self.verbose[ROW_PRUNING] = True

    def check_channel_pruning(self):
        # check channel pruning
        cp = self.different_compression_methods[CHANNEL_PRUNING]
        if not cp[TECHNIQUE_ENABLED]:
            return
        else:
            shared_parameters = cp[SHARED_PARAMETERS]
            if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
                for group_name, module_name_list, method_parameters in cp[DIFFERENT_GROUPS]:
                    for module_name in module_name_list:
                        module = recursive_getattr(self.model, module_name)
                        module.channel_pruning_enabled = True
                if not self.verbose[CHANNEL_PRUNING]:
                    logger.info(f'Channel pruning is enabled at step {self.training_steps}')
                    self.verbose[CHANNEL_PRUNING] = True

    def check_all_modules(self):
        # check all different compression methods we have
        self.check_weight_quantization()
        self.check_activation_quantization()
        self.check_sparse_pruning()
        self.check_head_pruning()
        self.check_row_pruning()
        self.check_channel_pruning()

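    # step() is meant to be called once per training step. With
    # step_zero_check=True the step count is left untouched, which lets the
    # caller apply any technique whose schedule offset is 0 before training
    # starts.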
    def step(self, step_zero_check=False):
        if not step_zero_check:
            self.training_steps += 1
        self.check_all_modules()