Add snip_momentum structured pruning which can support higher sparse ratio with minor accuracy loss (#3300)

Signed-off-by: Tian, Feng <feng.tian@intel.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Author: Tian, Feng
Date: 2023-05-11 01:33:48 +08:00 (committed by GitHub)
Parent: b31b46c0d1
Commit: 6938c449de
10 changed files with 181 additions and 12 deletions

View File

@@ -11,6 +11,11 @@ from .constants import *
import os
import json
try:
import neural_compressor as nc
except ImportError as e:
nc = None
def check_deepspeed_config(config):
if isinstance(config, dict):
@@ -117,6 +122,26 @@ def init_compression(model, deepspeed_config, teacher_model=None, mpu=None):
layer_added_compress_methods = get_compress_methods(c_model, compress_methods, mpu=mpu)
compression_preparation(c_model, layer_added_compress_methods, mpu)
# For sparse pruning snip_momentum method
shared_parameters = compress_methods[SPARSE_PRUNING][SHARED_PARAMETERS]
if shared_parameters[SPARSE_PRUNING_ENABLED] and \
shared_parameters[SPARSE_PRUNING_METHOD] == SPARSE_PRUNING_METHOD_SNIP_MOMENTUM:
assert nc is not None, "please install the neural_compressor python package (via pip or conda) to use snip_momentum sparse pruning"
from .helper import generate_pruners, register_on_step_begin
from neural_compressor import WeightPruningConfig
config = WeightPruningConfig(target_sparsity=1 - shared_parameters[SPARSE_PRUNING_DENSE_RATIO],
pattern=shared_parameters[SPARSE_PRUNING_BLOCK_PATTERN],
pruning_frequency=shared_parameters[SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE],
start_step=shared_parameters[SPARSE_PRUNING_SCHEDULE_OFFSET],
end_step=shared_parameters[SPARSE_PRUNING_SCHEDULE_OFFSET_END],
excluded_op_names=shared_parameters[SPARSE_PRUNING_EXCLUDED_MODULES])
pruners = generate_pruners(config, c_model)
c_model.pruners = pruners
register_on_step_begin(c_model)
return model
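For reference, here is a minimal sketch of how the `shared_parameters` entries above feed into neural_compressor's `WeightPruningConfig`. The sample values are hypothetical and mirror the JSON example documented later in this commit; only the keyword arguments are taken from the code above.

```python
from neural_compressor import WeightPruningConfig

# Hypothetical shared_parameters dict, as produced by the config parser below.
shared_parameters = {
    "dense_ratio": 0.4,            # keep 40% of the weights
    "block_pattern": "4x1",
    "schedule_offset": 30,         # start pruning at step 30
    "schedule_offset_end": 90,     # stop pruning at step 90
    "schedule_offset_stride": 15,  # prune every 15 steps
    "excluded_modules": ["classifier", "pooler"],
}

# target_sparsity is derived from dense_ratio: 1 - 0.4 -> 60% global sparsity.
config = WeightPruningConfig(target_sparsity=1 - shared_parameters["dense_ratio"],
                             pattern=shared_parameters["block_pattern"],
                             pruning_frequency=shared_parameters["schedule_offset_stride"],
                             start_step=shared_parameters["schedule_offset"],
                             end_step=shared_parameters["schedule_offset_end"],
                             excluded_op_names=shared_parameters["excluded_modules"])
```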

View File

@@ -5,7 +5,7 @@
from .constants import *
import copy
from ..runtime.config_utils import get_scalar_param
from ..runtime.config_utils import get_scalar_param, get_list_param
def get_compression_config(param_dict):
@@ -221,15 +221,17 @@ def get_sparse_pruning(param_dict):
# shared parameters
output[SHARED_PARAMETERS] = get_sparse_pruning_shared_parameters(sub_param_dict)
# each sub-groups
if output[SHARED_PARAMETERS][SPARSE_PRUNING_ENABLED]:
if output[SHARED_PARAMETERS][SPARSE_PRUNING_ENABLED] and output[SHARED_PARAMETERS][
SPARSE_PRUNING_METHOD] != SPARSE_PRUNING_METHOD_SNIP_MOMENTUM:
assert DIFFERENT_GROUPS in sub_param_dict.keys(
), f"Sparse Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
), f"Sparse Pruning is enabled and not snip_momentum method, {DIFFERENT_GROUPS} must be specified"
output[DIFFERENT_GROUPS] = get_sparse_pruning_different_groups(sub_param_dict)
return output
def get_sparse_pruning_shared_parameters(param_dict):
output = {}
if SHARED_PARAMETERS in param_dict.keys():
sub_param_dict = param_dict[SHARED_PARAMETERS]
output[SPARSE_PRUNING_ENABLED] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_ENABLED,
@@ -237,10 +239,26 @@ def get_sparse_pruning_shared_parameters(param_dict):
output[SPARSE_PRUNING_METHOD] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_METHOD,
SPARSE_PRUNING_METHOD_DEFAULT)
assert output[SPARSE_PRUNING_METHOD] in [
SPARSE_PRUNING_METHOD_L1, SPARSE_PRUNING_METHOD_TOPK
], f"Invalid sparse pruning method. Supported types: [{SPARSE_PRUNING_METHOD_L1}, {SPARSE_PRUNING_METHOD_TOPK}]"
SPARSE_PRUNING_METHOD_L1, SPARSE_PRUNING_METHOD_TOPK, SPARSE_PRUNING_METHOD_SNIP_MOMENTUM
], f"Invalid sparse pruning method. Supported types: [{SPARSE_PRUNING_METHOD_L1}, {SPARSE_PRUNING_METHOD_TOPK}, {SPARSE_PRUNING_METHOD_SNIP_MOMENTUM}]"
output[SPARSE_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_SCHEDULE_OFFSET,
SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT)
if output[SPARSE_PRUNING_METHOD] == SPARSE_PRUNING_METHOD_SNIP_MOMENTUM:
output[SPARSE_PRUNING_BLOCK_PATTERN] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_BLOCK_PATTERN,
SPARSE_PRUNING_BLOCK_PATTERN_DEFAULT)
output[SPARSE_PRUNING_DENSE_RATIO] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_DENSE_RATIO,
SPARSE_PRUNING_DENSE_RATIO_DEFAULT)
assert output[SPARSE_PRUNING_DENSE_RATIO] > 0 and output[
SPARSE_PRUNING_DENSE_RATIO] < 1, "Invalid dense_ratio value. Must be within the range (0, 1)"
output[SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE] = get_scalar_param(
sub_param_dict, SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE, SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE_DEFAULT)
output[SPARSE_PRUNING_EXCLUDED_MODULES] = get_list_param(sub_param_dict, SPARSE_PRUNING_EXCLUDED_MODULES,
SPARSE_PRUNING_EXCLUDED_MODULES_DEFAULT)
output[SPARSE_PRUNING_SCHEDULE_OFFSET_END] = get_scalar_param(sub_param_dict,
SPARSE_PRUNING_SCHEDULE_OFFSET_END,
output[SPARSE_PRUNING_SCHEDULE_OFFSET])
assert output[SPARSE_PRUNING_SCHEDULE_OFFSET] <= output[
SPARSE_PRUNING_SCHEDULE_OFFSET_END], "Invalid values: schedule_offset must not exceed schedule_offset_end"
else:
output[SPARSE_PRUNING_ENABLED] = SPARSE_PRUNING_ENABLED_DEFAULT
output[SPARSE_PRUNING_METHOD] = SPARSE_PRUNING_METHOD_DEFAULT
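For readers unfamiliar with these helpers, a minimal sketch of the semantics assumed here: `get_scalar_param` and `get_list_param` return the configured value when the key is present and otherwise fall back to the supplied default. The functions below are illustrative stand-ins, not the actual `deepspeed.runtime.config_utils` implementations.

```python
from typing import Any, Dict, List


def get_scalar_param(param_dict: Dict[str, Any], param_name: str, param_default: Any) -> Any:
    # Illustrative stand-in: return the configured value, or the default if absent.
    return param_dict.get(param_name, param_default)


def get_list_param(param_dict: Dict[str, Any], param_name: str, param_default: List[Any]) -> List[Any]:
    # Illustrative stand-in: same lookup, used for list-valued fields such as excluded_modules.
    return param_dict.get(param_name, param_default)


# Example: excluded_modules falls back to [] when it is not present in the config.
assert get_list_param({"method": "snip_momentum"}, "excluded_modules", []) == []
```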

View File

@@ -12,6 +12,7 @@ SHARED_PARAMETERS = "shared_parameters"
DIFFERENT_GROUPS = "different_groups"
TECHNIQUE_ENABLED = "enabled"
TECHNIQUE_SCHEDULE_OFFSET = "schedule_offset"
TECHNIQUE_SCHEDULE_OFFSET_END = "schedule_offset_end"
DIFFERENT_GROUPS_PARAMETERS = "params"
DIFFERENT_GROUPS_MODULE_SCOPE = "modules"
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT = "*"
@@ -111,11 +112,25 @@ SPARSE_PRUNING_METHOD = "method"
SPARSE_PRUNING_METHOD_DEFAULT = "l1"
SPARSE_PRUNING_METHOD_L1 = "l1"
SPARSE_PRUNING_METHOD_TOPK = "topk"
SPARSE_PRUNING_METHOD_SNIP_MOMENTUM = "snip_momentum"
SPARSE_PRUNING_BLOCK_PATTERN = "block_pattern"
SPARSE_PRUNING_BLOCK_PATTERN_DEFAULT = "4x1"
SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE = "schedule_offset_stride"
SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE_DEFAULT = 1
SPARSE_PRUNING_SCHEDULE_OFFSET = TECHNIQUE_SCHEDULE_OFFSET
SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT = 1000
SPARSE_PRUNING_SCHEDULE_OFFSET_END = TECHNIQUE_SCHEDULE_OFFSET_END
SPARSE_PRUNING_SCHEDULE_OFFSET_END_DEFAULT = SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT
SPARSE_PRUNING_DENSE_RATIO = "dense_ratio"
SPARSE_PRUNING_DENSE_RATIO_DEFAULT = 0.1
SPARSE_PRUNING_EXCLUDED_MODULES = "excluded_modules"
SPARSE_PRUNING_EXCLUDED_MODULES_DEFAULT = []
###
# Row Pruning
###

View File

@@ -6,6 +6,12 @@
import torch
from .basic_layer import Embedding_Compress, LinearLayer_Compress, Conv2dLayer_Compress, BNLayer_Compress, ColumnParallelLinear_Compress, RowParallelLinear_Compress
from .constants import *
from deepspeed.utils import logger
try:
from neural_compressor.compression import pruner as nc_pruner
except ImportError as e:
nc_pruner = None
def recursive_getattr(model, module_name):
@@ -246,3 +252,71 @@ def convert_conv1d_to_linear(model, convert_type):
recursive_setattr(c_model, name, new_module)
return model
def generate_pruners(config, model):
"""Generate pruners.
Args:
config (`neural_compressor.WeightPruningConfig`)
An instance of the WeightPruningConfig class.
model (`torch.nn.Module`)
The torch module object to be pruned.
"""
assert nc_pruner is not None, "please install the neural_compressor python package (via pip or conda) to use snip_momentum sparse pruning"
from neural_compressor.compression.pruner.utils import process_config, parse_to_prune
from neural_compressor.compression.pruner.pruners import get_pruner
assert isinstance(model, torch.nn.Module)
pruners_info = process_config(config)
pruners = []
for info in pruners_info:
modules = parse_to_prune(info, model)
if modules == {}:
logger.warning("one pruner hooks no layers, please have a check")
pruners.append(get_pruner(info, modules))
info['modules'] = [key for key in modules.keys()]
info['len_of_modules'] = len(info['modules'])
logger.info(info)
return pruners
def register_on_step_begin(model):
"""Mount on_step_begin to the model.
Args:
model (`torch.nn.Module`)
The torch module object to be pruned.
"""
def hook(module, input):
for pruner in module.pruners:
pruner.on_step_begin(0)
hook_handle = model.register_forward_pre_hook(hook)
return hook_handle
def rewrite_optimizer_step(opt: torch.optim.Optimizer):
"""Mount on_before/after_optimizer_step to the optimizer.
Args:
opt (`torch.optim.Optimizer`)
The torch optimizer object to be hooked.
"""
def new_step(self, closure=None):
if hasattr(self, "pruners"):
for pruner in self.pruners:
pruner.on_before_optimizer_step()
if closure is not None:
res = self.orig_step(closure)
else:
res = self.orig_step()
if hasattr(self, "pruners"):
for pruner in self.pruners:
pruner.on_after_optimizer_step()
return res
opt.orig_step = opt.step
import types
opt.step = types.MethodType(new_step, opt)
return opt
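To illustrate the call order wired up by `register_on_step_begin` and `rewrite_optimizer_step`, here is a small self-contained sketch. `DummyPruner` is a stand-in for a neural_compressor pruner, and the `deepspeed.compression.helper` import path is assumed from this diff.

```python
import torch
from deepspeed.compression.helper import register_on_step_begin, rewrite_optimizer_step


class DummyPruner:
    """Stand-in for a neural_compressor pruner; records which callbacks fire."""

    def __init__(self):
        self.calls = []

    def on_step_begin(self, step):
        self.calls.append("on_step_begin")

    def on_before_optimizer_step(self):
        self.calls.append("on_before_optimizer_step")

    def on_after_optimizer_step(self):
        self.calls.append("on_after_optimizer_step")


model = torch.nn.Linear(4, 2)
model.pruners = [DummyPruner()]
register_on_step_begin(model)  # forward pre-hook fires on_step_begin before each forward

opt = torch.optim.SGD(model.parameters(), lr=0.1)
opt.pruners = model.pruners
rewrite_optimizer_step(opt)    # wraps opt.step with the before/after callbacks

loss = model(torch.randn(8, 4)).sum()  # forward pass -> on_step_begin
loss.backward()
opt.step()                             # -> on_before_optimizer_step, on_after_optimizer_step

print(model.pruners[0].calls)
# ['on_step_begin', 'on_before_optimizer_step', 'on_after_optimizer_step']
```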

View File

@@ -100,7 +100,9 @@ class compression_scheduler():
return
else:
shared_parameters = sp[SHARED_PARAMETERS]
if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
if self.training_steps >= shared_parameters[
TECHNIQUE_SCHEDULE_OFFSET] and self.training_steps <= shared_parameters[
TECHNIQUE_SCHEDULE_OFFSET_END]:
for group_name, module_name_list, method_parameters in sp[DIFFERENT_GROUPS]:
for module_name in module_name_list:
module = recursive_getattr(self.model, module_name)

View File

@@ -314,6 +314,12 @@ class DeepSpeedEngine(Module):
elif self.bfloat16_enabled():
self.optimizer = self._configure_bf16_optimizer(optimizer=None)
# Hook optimizer for snip_momentum pruning
if hasattr(model, 'pruners'):
from ..compression.helper import rewrite_optimizer_step
self.optimizer.pruners = model.pruners
rewrite_optimizer_step(self.optimizer)
# Bookkeeping for sparse support
self.sparse_tensor_module_names = set()
# if self.sparse_gradients_enabled():

View File

@@ -1435,6 +1435,25 @@ Different quantization sets, this is used for different quantization parameters.
}
```
```json
"compression_training": {
"sparse_pruning":{
"shared_parameters":{
"enabled": true,
"schedule_offset": 30,
"schedule_offset_end": 90,
"schedule_offset_stride": 15,
"method": "snip_momentum",
"block_pattern": "4x1",
"dense_ratio": 0.4,
"excluded_modules": ['classifier', 'pooler']
},
"different_groups":{
}
}
}
```
<i>**shared_parameters**</i>: [dictionary]
Shared parameters for all sparse pruning groups.
@@ -1443,11 +1462,17 @@ Shared parameters for all sparse pruning groups.
| ----- | ----- | ----- |
| <i>**enabled**</i>: [boolean] | Enable sparse pruning or not. | `false` |
| <i>**schedule_offset**</i>: [integer] | Enable sparse pruning after scheduled steps (can be treated as warmup steps). | `0` |
| <i>**method**</i>: [string] | Choose different pruning methods, l1 (static, magnitude based) or topk (dynamic, learnable). | `"l1"` |
| <i>**schedule_offset_end**</i>: [integer] | Disable sparse pruning after scheduled steps, mandatory for `snip_momentum`. | `0` |
| <i>**schedule_offset_stride**</i>: [integer] | The stride of pruning on training steps, mandatory for `snip_momentum`. | `1` |
| <i>**method**</i>: [string] | Choose different pruning methods, l1 (static, magnitude based), topk (dynamic, learnable) or snip_momentum (structured pruning). | `"l1"` |
| <i>**block_pattern**</i>: [string] | Choose different structured pruning block patterns, NxM or N:M (N and M are integers). For instance, "4x1" and "2:4" are common block patterns; mandatory for `snip_momentum`. | `"4x1"` |
| <i>**dense_ratio**</i>: [float] | Used to get the targeted global sparsity ratio, mandatory for `snip_momentum`. | `0.1` |
| <i>**excluded_modules**</i>: [list] | Excluded pruning scope on some special modules like output layer. | `[]` |
<i>**different_groups**</i>: [dictionary]
Different pruning sets, this is used for different pruning parameters. In this example, we give one set. In practice, you can choose the number of sets based on your requirements.
Note that for the `snip_momentum` method, this can be left empty.
| Fields | Value | Default |
| ----- | ----- | ----- |

View File

@@ -158,7 +158,7 @@ Pruning aims to reduce the number of parameters and operations involved in gener
| **Method** | **Type** |
| --------------------- | ------------ |
| [Sparse pruning](#141-sparse-pruning) | Unstructured |
| [Sparse pruning](#141-sparse-pruning) | Unstructured and Structured |
| [Row pruning](#142-row-pruning) | Structured |
| [Head pruning](#143-head-pruning) | Structured |
| [Channel pruning](#144-channel-pruning) | Structured |
@@ -166,7 +166,7 @@ Pruning aims to reduce the number of parameters and operations involved in gener
#### 1.4.1 Sparse Pruning
**What is sparse pruning**
Sparse pruning means we set some of the elements in each weight matrix with zero values. There is no structure pattern in the zero values. One way to perform pruning is based on the absolute value of the weight parameters, see for instance [this paper](https://arxiv.org/abs/1506.02626).
Sparse pruning means we set some of the elements in each weight matrix to zero. Depending on the pruning method chosen, the zero values may follow a structured or an unstructured pattern. One way to perform pruning is based on the absolute value of the weight parameters, see for instance [this paper](https://arxiv.org/abs/1506.02626). Another way to perform pruning is based on the weights' effect on the loss function when they are masked, see for instance [this paper](https://arxiv.org/abs/1810.02340).
**When to use sparse pruning**
@@ -178,11 +178,13 @@ Sparse pruning can be enabled and configured using the DeepSpeed config JSON fil
(1)`schedule_offset`, we empirically find that when using `method: topk`, it's better to set the `schedule_offset` to a large value such as 10% of the total training steps.
(2)`method`, we support L1 norm and topk methods. Users are welcome to contribute more methods.
(2)`method`, we support L1 norm, topk and snip_momentum methods. Users are welcome to contribute more methods.
(3)`sp1`, users can expand more groups such as `sp2`, `sp3`, etc.
(3)`sp1`, users can expand more groups such as `sp2`, `sp3`, etc. Note that this is not needed for the snip_momentum method.
(4)`dense_ratio`, for unstructured sparse pruning, the dense ratio could be less than 0.1 for BRET-base model while still yielding a good accuracy. For ResNet-50, the dense ratio could be as low as 0.3 while still having good accuracy on ImageNet.
(4)`dense_ratio`, for unstructured sparse pruning, the dense ratio could be less than 0.1 for a BERT-base model while still yielding good accuracy. For ResNet-50, the dense ratio could be as low as 0.3 while still having good accuracy on ImageNet. For structured sparse pruning such as snip_momentum, the dense ratio should be specified in shared_parameters and is used to calculate the global sparsity ratio.
(5)`schedule_offset_stride`, `block_pattern` and `schedule_offset_end`, these are used to specify the pruning frequency in steps, the block-wise pruning pattern (NxM or N:M), and the step at which pruning ends. For the snip_momentum method, these configurations are mandatory.
The client code change is the same as [weight quantization](#12-weight-quantization).
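As a rough sketch of that client-side flow, assuming a `ds_config.json` containing the snip_momentum block shown earlier (plus the usual DeepSpeed fields such as `train_batch_size` and `optimizer`) and taking the `init_compression` import path from this commit; the model and training loop are schematic placeholders:

```python
import deepspeed
import torch
from deepspeed.compression.compression import init_compression

# Schematic model; replace with your own network.
model = torch.nn.Sequential(torch.nn.Linear(768, 768), torch.nn.ReLU(), torch.nn.Linear(768, 2))

# init_compression attaches the snip_momentum pruners and the forward pre-hook;
# deepspeed.initialize then wraps optimizer.step so the pruning callbacks fire
# between schedule_offset and schedule_offset_end.
model = init_compression(model, "ds_config.json")
engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config="ds_config.json")

for step in range(100):                  # schematic training loop
    inputs = torch.randn(8, 768).to(engine.device)
    loss = engine(inputs).sum()          # schematic loss
    engine.backward(loss)
    engine.step()
```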

View File

@@ -0,0 +1 @@
neural-compressor==2.1.0

View File

@@ -65,6 +65,7 @@ extras_require = {
'autotuning': fetch_requirements('requirements/requirements-autotuning.txt'),
'autotuning_ml': fetch_requirements('requirements/requirements-autotuning-ml.txt'),
'sparse_attn': fetch_requirements('requirements/requirements-sparse_attn.txt'),
'sparse': fetch_requirements('requirements/requirements-sparse_pruning.txt'),
'inf': fetch_requirements('requirements/requirements-inf.txt'),
'sd': fetch_requirements('requirements/requirements-sd.txt')
}