mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
I am trying to give some test files better owner labels than `module: unknown`. I am not sure about them, but they seem pretty reasonable. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163203 Approved by: https://github.com/jcaip
1095 lines
37 KiB
Python
1095 lines
37 KiB
Python
# Owner(s): ["module: sparse"]
|
|
import copy
|
|
import random
|
|
|
|
import torch
|
|
from torch import nn
|
|
from torch.ao.pruning._experimental.pruner import (
|
|
BaseStructuredSparsifier,
|
|
FakeStructuredSparsity,
|
|
FPGMPruner,
|
|
LSTMSaliencyPruner,
|
|
SaliencyPruner,
|
|
)
|
|
from torch.nn.utils import parametrize
|
|
from torch.testing._internal.common_pruning import (
|
|
Conv2dActivation,
|
|
Conv2dBias,
|
|
Conv2dPadBias,
|
|
Conv2dPool,
|
|
Conv2dPoolFlatten,
|
|
Conv2dPoolFlattenFunctional,
|
|
LinearActivation,
|
|
LinearActivationFunctional,
|
|
LinearBias,
|
|
LSTMLayerNormLinearModel,
|
|
LSTMLinearModel,
|
|
rows_are_subset,
|
|
SimpleConv2d,
|
|
SimpleLinear,
|
|
)
|
|
from torch.testing._internal.common_utils import (
|
|
raise_on_run_directly,
|
|
skipIfTorchDynamo,
|
|
TestCase,
|
|
)
|
|
|
|
|
|
# Devices every test is run on: always the CPU, plus CUDA when it is available.
# A set is used so the CPU is listed only once on machines without CUDA.
DEVICES = {torch.device("cpu")}
if torch.cuda.is_available():
    DEVICES.add(torch.device("cuda"))
|
|
|
|
|
|
class SimplePruner(BaseStructuredSparsifier):
    """Deterministic test pruner: clears index 1 of the first parametrization's mask."""

    def update_mask(self, module, tensor_name, **kwargs):
        parametrization_list = getattr(module.parametrizations, tensor_name)
        first_parametrization = parametrization_list[0]
        first_parametrization.mask[1] = False
|
|
|
|
|
|
class ImplementedPruner(BaseStructuredSparsifier):
    def update_mask(self, module, tensor_name, **kwargs):
        """Randomly mark ``num_rows // 3`` mask rows as pruned (roughly one third).

        Rows are chosen with ``random.sample`` from the module-level ``random``
        generator, so the selection differs between calls unless it is seeded.
        """
        mask = module.parametrizations[tensor_name][0].mask
        candidate_rows = list(range(len(mask)))
        mask[random.sample(candidate_rows, len(candidate_rows) // 3)] = False
|
|
|
|
|
|
class BottomHalfLSTMPruner(BaseStructuredSparsifier):
    """
    Test-only pruner that splits each structured mask into four equal chunks
    and clears the second half of every chunk.
    """

    def update_mask(self, module, tensor_name, **kwargs):
        parametrization_list = getattr(module.parametrizations, tensor_name)
        structured = (
            p for p in parametrization_list if isinstance(p, FakeStructuredSparsity)
        )
        for parametrization in structured:
            # torch.split returns views, so clearing a chunk also edits the mask.
            chunks = torch.split(parametrization.mask, len(parametrization.mask) // 4)
            for chunk in chunks:
                chunk[len(chunk) // 2 :] = False
            parametrization.mask.data = torch.cat(chunks).data
|
|
|
|
|
|
class TestSaliencyPruner(TestCase):
|
|
def test_saliency_pruner_update_mask(self):
|
|
"""Test that we prune out the row with the lowest saliency (first row)"""
|
|
model = SimpleLinear()
|
|
with torch.no_grad():
|
|
model.linear1.weight = nn.Parameter(
|
|
torch.Tensor([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]])
|
|
)
|
|
pruning_config = [{"tensor_fqn": "linear1.weight", "sparsity_level": 0.5}]
|
|
pruner = SaliencyPruner({})
|
|
|
|
pruner.prepare(model, pruning_config)
|
|
pruner.enable_mask_update = True
|
|
pruner.step()
|
|
pruned_model = pruner.prune()
|
|
|
|
expected = torch.Tensor([[3, 3, 3, 3], [4, 4, 4, 4]])
|
|
pruned = pruned_model.linear1.weight
|
|
|
|
assert expected.shape == pruned.shape
|
|
assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()
|
|
|
|
def test_lstm_saliency_pruner_update_mask(self):
|
|
model = LSTMLinearModel(
|
|
input_dim=2,
|
|
hidden_dim=2,
|
|
output_dim=2,
|
|
num_layers=1,
|
|
)
|
|
|
|
manual_weights = torch.Tensor(
|
|
[[1, 1], [2, 2], [2, 2], [1, 1], [-1, -1], [-2, -2], [-2, -2], [-1, -1]]
|
|
)
|
|
|
|
with torch.no_grad():
|
|
model.lstm.weight_ih_l0 = nn.Parameter(manual_weights)
|
|
model.lstm.weight_hh_l0 = nn.Parameter(torch.Tensor(manual_weights))
|
|
model.lstm.bias_ih_l0 = nn.Parameter(manual_weights[:, 0])
|
|
model.lstm.bias_hh_l0 = nn.Parameter(manual_weights[:, 0])
|
|
|
|
config = [
|
|
{"tensor_fqn": "lstm.weight_ih_l0"},
|
|
{"tensor_fqn": "lstm.weight_hh_l0"},
|
|
]
|
|
lstm_input = torch.ones((1, 2))
|
|
fx_pruner = LSTMSaliencyPruner({"sparsity_level": 0.5})
|
|
fx_pruner.prepare(model, config)
|
|
fx_pruner.enable_mask_update = True
|
|
fx_pruner.step()
|
|
|
|
model.eval()
|
|
pruned_model = fx_pruner.prune()
|
|
pruned_model.eval()
|
|
|
|
# make sure both models run
|
|
model(lstm_input)
|
|
pruned_model(lstm_input)
|
|
|
|
# make sure lowest saliency rows are pruned
|
|
expected = torch.Tensor([[2, 2], [2, 2], [-2, -2], [-2, -2]])
|
|
pruned = model.lstm.weight_ih_l0
|
|
assert expected.shape == pruned.shape
|
|
assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()
|
|
|
|
expected = torch.Tensor([[2], [2], [-2], [-2]])
|
|
pruned = model.lstm.weight_hh_l0
|
|
assert expected.shape == pruned.shape
|
|
assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()
|
|
|
|
expected = torch.Tensor([2, 2, -2, -2])
|
|
for pruned in [model.lstm.bias_ih_l0, model.lstm.bias_hh_l0]:
|
|
assert expected.shape == pruned.shape
|
|
assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()
|
|
|
|
|
|
class TestBaseStructuredSparsifier(TestCase):
|
|
def _check_pruner_prepared(self, model, pruner, device):
|
|
for config in pruner.groups:
|
|
module = config["module"]
|
|
assert module.weight.device.type == device.type
|
|
# Check mask exists
|
|
assert config["tensor_fqn"] in pruner.state
|
|
# Check parametrization exists and is correct
|
|
assert parametrize.is_parametrized(module)
|
|
assert hasattr(module, "parametrizations")
|
|
# Assume that this is the 1st/only parametrization
|
|
assert type(module.parametrizations.weight[0]) == FakeStructuredSparsity
|
|
|
|
def _check_pruner_valid_before_step(self, model, pruner, device):
|
|
for config in pruner.groups:
|
|
modules = []
|
|
if type(config["module"]) is tuple:
|
|
modules.extend(config["module"])
|
|
else:
|
|
module = config["module"]
|
|
modules.append(module)
|
|
for module in modules:
|
|
assert module.weight.device.type == device.type
|
|
assert module.parametrizations.weight[0].mask.dtype == torch.bool
|
|
|
|
def _check_pruner_valid_after_step(self, model, pruner, mask, device):
|
|
for config in pruner.groups:
|
|
modules = []
|
|
if type(config["module"]) is tuple:
|
|
modules.extend(config["module"])
|
|
else:
|
|
module = config["module"]
|
|
modules.append(module)
|
|
for module in modules:
|
|
assert module.weight.device.type == device.type
|
|
total = module.parametrizations.weight[0].mask.numel()
|
|
assert (
|
|
module.parametrizations.weight[0].mask.count_nonzero()
|
|
== total - mask
|
|
)
|
|
|
|
def _test_constructor_on_device(self, model, device):
|
|
self.assertRaisesRegex(
|
|
TypeError,
|
|
"BaseStructuredSparsifier.*update_mask",
|
|
BaseStructuredSparsifier,
|
|
)
|
|
model1 = copy.deepcopy(model).to(device)
|
|
pruner = SimplePruner(None)
|
|
pruner.prepare(model1, None)
|
|
pruner.enable_mask_update = True
|
|
for g in pruner.groups:
|
|
module = g["module"]
|
|
assert module.weight.device.type == device.type
|
|
assert len(pruner.groups) == 5
|
|
pruner.step()
|
|
# Can instantiate the model with configs
|
|
model2 = copy.deepcopy(model).to(device)
|
|
pruner = SimplePruner({"test": 3})
|
|
pruner.prepare(model2, [{"tensor_fqn": "seq.0.weight"}])
|
|
assert len(pruner.groups) == 1
|
|
assert pruner.groups[0]["module_fqn"] == "seq.0"
|
|
assert "test" in pruner.groups[0]
|
|
assert pruner.groups[0]["test"] == 3
|
|
|
|
def test_constructor(self):
|
|
model = SimpleLinear()
|
|
for device in DEVICES:
|
|
self._test_constructor_on_device(model, torch.device(device))
|
|
|
|
def _test_prepare_linear_on_device(self, model, device):
|
|
model = copy.deepcopy(model).to(device)
|
|
x = torch.ones(128, 7, device=device)
|
|
pruner = SimplePruner(None)
|
|
pruner.prepare(model, None)
|
|
self._check_pruner_prepared(model, pruner, device)
|
|
assert model(x).shape == (128, 10)
|
|
|
|
def test_prepare_linear(self):
|
|
models = [
|
|
SimpleLinear(),
|
|
LinearBias(),
|
|
LinearActivation(),
|
|
LinearActivationFunctional(),
|
|
] # without and with bias
|
|
for device in DEVICES:
|
|
for model in models:
|
|
self._test_prepare_linear_on_device(model, torch.device(device))
|
|
|
|
def _test_prepare_conv2d_on_device(self, model, expected_shape, config, device):
|
|
x = torch.ones((1, 1, 28, 28), device=device)
|
|
pruner = SimplePruner(None)
|
|
pruner.prepare(model, config)
|
|
self._check_pruner_prepared(model, pruner, device)
|
|
assert model(x).shape == expected_shape
|
|
|
|
def test_prepare_conv2d(self):
|
|
models = [
|
|
SimpleConv2d(),
|
|
Conv2dBias(),
|
|
Conv2dActivation(),
|
|
Conv2dPadBias(),
|
|
Conv2dPool(),
|
|
]
|
|
shapes = [
|
|
(1, 52, 20, 20),
|
|
(1, 52, 18, 18),
|
|
(1, 52, 18, 18),
|
|
(1, 52, 24, 24),
|
|
(1, 52, 3, 3),
|
|
]
|
|
configs = [None, None, None, None, None]
|
|
for device in DEVICES:
|
|
for model, shape, config in zip(models, shapes, configs):
|
|
model = model.to(device)
|
|
self._test_prepare_conv2d_on_device(
|
|
model, shape, config, torch.device(device)
|
|
)
|
|
|
|
def _test_step_linear_on_device(self, model, device):
|
|
model = model.to(device)
|
|
pruner = SimplePruner(None)
|
|
pruner.prepare(model, None)
|
|
pruner.enable_mask_update = True
|
|
self._check_pruner_valid_before_step(model, pruner, device)
|
|
pruner.step()
|
|
self._check_pruner_valid_after_step(model, pruner, 1, device)
|
|
|
|
def test_step_linear(self):
|
|
models = [
|
|
SimpleLinear(),
|
|
LinearBias(),
|
|
LinearActivation(),
|
|
LinearActivationFunctional(),
|
|
]
|
|
for device in DEVICES:
|
|
for model in models:
|
|
self._test_step_linear_on_device(model, torch.device(device))
|
|
|
|
def _test_step_conv2d_on_device(self, model, expected_shape, config, device):
|
|
model = model.to(device)
|
|
x = torch.ones((1, 1, 28, 28), device=device)
|
|
pruner = SimplePruner(None)
|
|
pruner.prepare(model, config)
|
|
pruner.enable_mask_update = True
|
|
self._check_pruner_valid_before_step(model, pruner, device)
|
|
pruner.step()
|
|
self._check_pruner_valid_after_step(model, pruner, 1, device)
|
|
assert model(x).shape == expected_shape
|
|
|
|
@skipIfTorchDynamo("TorchDynamo fails with unknown reason")
|
|
def test_step_conv2d(self):
|
|
models = [
|
|
SimpleConv2d(),
|
|
Conv2dBias(),
|
|
Conv2dActivation(),
|
|
Conv2dPadBias(),
|
|
Conv2dPool(),
|
|
]
|
|
shapes = [
|
|
(1, 52, 20, 20),
|
|
(1, 52, 18, 18),
|
|
(1, 52, 18, 18),
|
|
(1, 52, 24, 24),
|
|
(1, 52, 3, 3),
|
|
]
|
|
configs = [None, None, None, None, None]
|
|
for device in DEVICES:
|
|
for model, shape, config in zip(models, shapes, configs):
|
|
self._test_step_conv2d_on_device(
|
|
model, shape, config, torch.device(device)
|
|
)
|
|
|
|
def _check_pruner_pruned(self, model, pruner, device):
|
|
for config in pruner.groups:
|
|
module = config["module"]
|
|
assert not hasattr(module, "parametrizations")
|
|
assert not hasattr(module, "mask")
|
|
|
|
def _test_linear_on_device(
|
|
self, model, config, expected_shape, device, also_prune_bias
|
|
):
|
|
model = model.to(device)
|
|
model.eval()
|
|
num_original_params = sum(p.numel() for p in model.parameters())
|
|
x = torch.ones(128, 7, device=device)
|
|
|
|
pruner = ImplementedPruner({"prune_bias": also_prune_bias})
|
|
pruner.prepare(model, config)
|
|
pruner.enable_mask_update = True
|
|
pruner.step()
|
|
|
|
y_expected = model(x)
|
|
|
|
assert y_expected.shape == (128, 10)
|
|
self._check_pruner_prepared(model, pruner, device)
|
|
|
|
# Pruning step
|
|
pruned = pruner.prune()
|
|
y_pruned = pruned(x)
|
|
num_pruned_params = sum(p.numel() for p in pruned.parameters())
|
|
|
|
assert y_pruned.shape == expected_shape
|
|
self._check_pruner_pruned(model, pruner, device)
|
|
if y_pruned.shape == y_expected.shape:
|
|
assert torch.isclose(y_expected, y_pruned, rtol=1e-05, atol=1e-07).all()
|
|
assert num_pruned_params < num_original_params
|
|
|
|
def test_prune_linear_linear(self):
|
|
r"""test pruning linear-> linear modules"""
|
|
configs, shapes = [], []
|
|
configs.append(
|
|
[
|
|
{"tensor_fqn": "seq.0.weight"},
|
|
{"tensor_fqn": "seq.1.weight"},
|
|
{"tensor_fqn": "seq.2.weight"},
|
|
]
|
|
)
|
|
shapes.append((128, 10))
|
|
|
|
configs.append(
|
|
[
|
|
{"tensor_fqn": "seq.0.weight"},
|
|
{"tensor_fqn": "seq.1.weight"},
|
|
{"tensor_fqn": "seq.2.weight"},
|
|
{"tensor_fqn": "linear1.weight"},
|
|
]
|
|
)
|
|
shapes.append((128, 10))
|
|
|
|
configs.append(
|
|
[
|
|
{"tensor_fqn": "seq.0.weight"},
|
|
{"tensor_fqn": "seq.2.weight"},
|
|
]
|
|
)
|
|
shapes.append((128, 10))
|
|
for device in DEVICES:
|
|
for also_prune_bias in [True, False]:
|
|
for config, shape in zip(configs, shapes):
|
|
self._test_linear_on_device(
|
|
SimpleLinear(),
|
|
config,
|
|
shape,
|
|
torch.device(device),
|
|
also_prune_bias,
|
|
)
|
|
|
|
def test_prune_linear_bias_linear(self):
|
|
# linear(bias) -> linear(no bias)
|
|
configs, shapes = [], []
|
|
configs.append(
|
|
[
|
|
{"tensor_fqn": "seq.0.weight"},
|
|
{"tensor_fqn": "seq.1.weight"},
|
|
]
|
|
)
|
|
shapes.append((128, 10))
|
|
|
|
# linear(bias) -> linear(bias)
|
|
configs.append(
|
|
[
|
|
{"tensor_fqn": "seq.2.weight"},
|
|
{"tensor_fqn": "seq.3.weight"},
|
|
]
|
|
)
|
|
shapes.append((128, 10))
|
|
|
|
# linear(no bias) -> linear(bias)
|
|
configs.append(
|
|
[
|
|
{"tensor_fqn": "seq.0.weight"},
|
|
{"tensor_fqn": "seq.1.weight"},
|
|
{"tensor_fqn": "seq.2.weight"},
|
|
]
|
|
)
|
|
shapes.append((128, 10))
|
|
|
|
for device in DEVICES:
|
|
for also_prune_bias in [True, False]:
|
|
for config, shape in zip(configs, shapes):
|
|
self._test_linear_on_device(
|
|
LinearBias(),
|
|
config,
|
|
shape,
|
|
torch.device(device),
|
|
also_prune_bias,
|
|
)
|
|
|
|
def test_prune_linear_activation_linear(self):
|
|
config = [
|
|
{"tensor_fqn": "seq.0.weight"},
|
|
{"tensor_fqn": "seq.2.weight"},
|
|
{"tensor_fqn": "seq.4.weight"},
|
|
{"tensor_fqn": "linear1.weight"},
|
|
]
|
|
shape = (128, 10)
|
|
|
|
for device in DEVICES:
|
|
for also_prune_bias in [True, False]:
|
|
# test version with nn.Modules
|
|
self._test_linear_on_device(
|
|
LinearActivation(),
|
|
config,
|
|
shape,
|
|
torch.device(device),
|
|
also_prune_bias,
|
|
)
|
|
# test functional version
|
|
self._test_linear_on_device(
|
|
LinearActivationFunctional(),
|
|
config,
|
|
shape,
|
|
torch.device(device),
|
|
also_prune_bias,
|
|
)
|
|
|
|
def _test_conv2d_on_device(
|
|
self, model, config, x, expected_shape, device, also_prune_bias
|
|
):
|
|
model = model.to(device)
|
|
num_original_params = sum(p.numel() for p in model.parameters())
|
|
model.eval()
|
|
|
|
pruner = ImplementedPruner({"prune_bias": also_prune_bias})
|
|
pruner.prepare(model, config)
|
|
pruner.enable_mask_update = True
|
|
pruner.step()
|
|
|
|
y_expected = model(x)
|
|
assert y_expected.shape == expected_shape
|
|
|
|
self._check_pruner_prepared(model, pruner, device)
|
|
|
|
# Fusion step
|
|
pruned = pruner.prune()
|
|
y_pruned = pruned(x)
|
|
num_pruned_params = sum(p.numel() for p in pruned.parameters())
|
|
|
|
assert y_pruned.shape == expected_shape
|
|
self._check_pruner_pruned(model, pruner, device)
|
|
if y_pruned.shape == y_expected.shape:
|
|
# TODO This rtol is a little high, need to double check if something specific is causing this to fail
|
|
assert torch.isclose(
|
|
y_expected,
|
|
y_pruned,
|
|
rtol=1e-3,
|
|
atol=1e-3,
|
|
).all(), f"fail for {type(model)}"
|
|
# only time this should be equal is when all layers have padding and we can't prune
|
|
assert num_pruned_params <= num_original_params
|
|
|
|
def test_prune_conv2d_conv2d(self):
|
|
configs, shapes = [], []
|
|
# all within sequential blocks
|
|
configs.append(
|
|
[
|
|
{"tensor_fqn": "seq.0.weight"},
|
|
]
|
|
)
|
|
shapes.append((1, 52, 20, 20))
|
|
# prune across sequential blocks
|
|
configs.append(
|
|
[
|
|
{"tensor_fqn": "seq.0.weight"},
|
|
{"tensor_fqn": "seq.1.weight"},
|
|
{"tensor_fqn": "conv2d1.weight"},
|
|
]
|
|
)
|
|
shapes.append((1, 52, 20, 20))
|
|
|
|
for device in DEVICES:
|
|
x = torch.ones((1, 1, 28, 28), device=device)
|
|
for also_prune_bias in [True, False]:
|
|
for config, shape in zip(configs, shapes):
|
|
self._test_conv2d_on_device(
|
|
SimpleConv2d(),
|
|
config,
|
|
x,
|
|
shape,
|
|
torch.device(device),
|
|
also_prune_bias,
|
|
)
|
|
|
|
def test_prune_conv2d_bias_conv2d(self):
|
|
# Conv2d with Bias and no Activation
|
|
configs, shapes = [], []
|
|
# conv2d(bias) -> conv2d(bias)
|
|
configs.append(
|
|
[
|
|
{"tensor_fqn": "seq.0.weight"},
|
|
{"tensor_fqn": "seq.1.weight"},
|
|
]
|
|
)
|
|
shapes.append((1, 52, 18, 18))
|
|
|
|
# conv2d(no bias) -> conv2d(bias)
|
|
configs.append(
|
|
[
|
|
{"tensor_fqn": "seq.0.weight"},
|
|
{"tensor_fqn": "seq.1.weight"},
|
|
{"tensor_fqn": "conv2d1.weight"},
|
|
]
|
|
)
|
|
shapes.append((1, 52, 18, 18))
|
|
|
|
# conv2d(bias) -> conv2d(no bias)
|
|
configs.append(
|
|
[
|
|
{"tensor_fqn": "seq.0.weight"},
|
|
{"tensor_fqn": "seq.1.weight"},
|
|
{"tensor_fqn": "seq.2.weight"},
|
|
]
|
|
)
|
|
shapes.append((1, 52, 18, 18))
|
|
|
|
for device in DEVICES:
|
|
x = torch.ones((1, 1, 28, 28), device=device)
|
|
for also_prune_bias in [True, False]:
|
|
for config, shape in zip(configs, shapes):
|
|
self._test_conv2d_on_device(
|
|
Conv2dBias(),
|
|
config,
|
|
x,
|
|
shape,
|
|
torch.device(device),
|
|
also_prune_bias,
|
|
)
|
|
|
|
def test_prune_conv2d_activation_conv2d(self):
|
|
# Conv2d with Activation and no Bias
|
|
configs, shapes = [], []
|
|
|
|
# conv2d(no bias) -> activation -> conv2d(no bias)
|
|
configs.append(
|
|
[
|
|
{"tensor_fqn": "seq.4.weight"},
|
|
]
|
|
)
|
|
shapes.append((1, 52, 18, 18))
|
|
|
|
# conv2d(bias) -> activation -> conv2d(bias)
|
|
configs.append(
|
|
[
|
|
{"tensor_fqn": "seq.0.weight"},
|
|
{"tensor_fqn": "seq.2.weight"},
|
|
]
|
|
)
|
|
shapes.append((1, 52, 18, 18))
|
|
|
|
# conv2d(bias) -> activation -> conv2d(no bias)
|
|
configs.append(
|
|
[
|
|
{"tensor_fqn": "seq.2.weight"},
|
|
{"tensor_fqn": "seq.4.weight"},
|
|
]
|
|
)
|
|
shapes.append((1, 52, 18, 18))
|
|
|
|
# conv2d(no bias) -> activation -> conv2d(bias)
|
|
configs.append(
|
|
[
|
|
{"tensor_fqn": "conv2d1.weight"},
|
|
]
|
|
)
|
|
shapes.append((1, 52, 18, 18))
|
|
|
|
for device in DEVICES:
|
|
x = torch.ones((1, 1, 28, 28), device=device)
|
|
for also_prune_bias in [True, False]:
|
|
for config, shape in zip(configs, shapes):
|
|
self._test_conv2d_on_device(
|
|
Conv2dActivation(),
|
|
config,
|
|
x,
|
|
shape,
|
|
torch.device(device),
|
|
also_prune_bias,
|
|
)
|
|
|
|
def test_prune_conv2d_padding_conv2d(self):
|
|
# Conv2d with Padded layers after Bias layers
|
|
configs, shapes = [], []
|
|
|
|
# conv(padded, bias) -> conv(padded, bias)
|
|
configs.append(
|
|
[
|
|
{"tensor_fqn": "seq.4.weight"},
|
|
]
|
|
)
|
|
shapes.append((1, 52, 24, 24))
|
|
|
|
# conv(no bias, no pad) -> conv(padded, bias)
|
|
configs.append(
|
|
[
|
|
{"tensor_fqn": "seq.2.weight"},
|
|
]
|
|
)
|
|
shapes.append((1, 52, 24, 24))
|
|
|
|
# conv(padded, bias) -> conv ( no bias ,no pad)
|
|
configs.append(
|
|
[
|
|
{"tensor_fqn": "seq.0.weight"},
|
|
]
|
|
)
|
|
shapes.append((1, 52, 24, 24))
|
|
# conv(pad, bias) -> conv(no pad, bias)
|
|
configs.append(
|
|
[
|
|
{"tensor_fqn": "seq.6.weight"},
|
|
]
|
|
)
|
|
shapes.append((1, 52, 24, 24))
|
|
# conv(no pad, bias) -> conv(pad, bias)
|
|
configs.append(
|
|
[
|
|
{"tensor_fqn": "seq.8.weight"},
|
|
]
|
|
)
|
|
shapes.append((1, 52, 24, 24))
|
|
|
|
for device in DEVICES:
|
|
x = torch.ones((1, 1, 28, 28), device=device)
|
|
for also_prune_bias in [True, False]:
|
|
for config, shape in zip(configs, shapes):
|
|
self._test_conv2d_on_device(
|
|
Conv2dPadBias(),
|
|
config,
|
|
x,
|
|
shape,
|
|
torch.device(device),
|
|
also_prune_bias,
|
|
)
|
|
|
|
def test_prune_conv2d_pool_conv2d(self):
|
|
# Conv2d with Pooling layers
|
|
config = [
|
|
{"tensor_fqn": "seq.0.weight"},
|
|
{"tensor_fqn": "seq.3.weight"},
|
|
{"tensor_fqn": "conv2d1.weight"},
|
|
{"tensor_fqn": "conv2d2.weight"},
|
|
]
|
|
shape = (1, 52, 3, 3)
|
|
|
|
for device in DEVICES:
|
|
x = torch.ones((1, 1, 28, 28), device=device)
|
|
for also_prune_bias in [True, False]:
|
|
self._test_conv2d_on_device(
|
|
Conv2dPool(),
|
|
config,
|
|
x,
|
|
shape,
|
|
torch.device(device),
|
|
also_prune_bias,
|
|
)
|
|
|
|
@skipIfTorchDynamo("TorchDynamo fails with unknown reason")
|
|
def test_complex_conv2d(self):
|
|
"""Test fusion for models that contain Conv2d & Linear modules.
|
|
Currently supports: Conv2d-Pool2d-Flatten-Linear, Skip-add"""
|
|
config = [
|
|
{"tensor_fqn": "seq.0.weight"},
|
|
{"tensor_fqn": "seq.3.weight"},
|
|
{"tensor_fqn": "conv2d1.weight"},
|
|
{"tensor_fqn": "conv2d2.weight"},
|
|
]
|
|
shape = (1, 13)
|
|
|
|
for device in DEVICES:
|
|
x = torch.ones((1, 1, 28, 28), device=device)
|
|
for also_prune_bias in [True, False]:
|
|
self._test_conv2d_on_device(
|
|
Conv2dPoolFlattenFunctional(),
|
|
config,
|
|
x,
|
|
shape,
|
|
torch.device(device),
|
|
also_prune_bias,
|
|
)
|
|
self._test_conv2d_on_device(
|
|
Conv2dPoolFlatten(),
|
|
config,
|
|
x,
|
|
shape,
|
|
torch.device(device),
|
|
also_prune_bias,
|
|
)
|
|
|
|
def test_prune_lstm_linear_multiple_layer(self):
|
|
"""
|
|
Test fusion support for LSTM(multi-layer) -> Linear
|
|
"""
|
|
model = LSTMLinearModel(
|
|
input_dim=8,
|
|
hidden_dim=8,
|
|
output_dim=8,
|
|
num_layers=2,
|
|
)
|
|
|
|
config = [
|
|
{"tensor_fqn": "lstm.weight_ih_l0"},
|
|
{"tensor_fqn": "lstm.weight_hh_l0"},
|
|
{"tensor_fqn": "lstm.weight_ih_l1"},
|
|
{"tensor_fqn": "lstm.weight_hh_l1"},
|
|
]
|
|
|
|
lstm_input = torch.ones((1, 8))
|
|
fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
|
|
fx_pruner.prepare(model, config)
|
|
|
|
fx_pruner.enable_mask_update = True
|
|
fx_pruner.step()
|
|
|
|
model.eval()
|
|
_, _ = model(lstm_input)
|
|
pruned_model = fx_pruner.prune()
|
|
pruned_model.eval()
|
|
_, _ = pruned_model(lstm_input)
|
|
|
|
expected_params = dict(model.named_parameters())
|
|
for name, param in model.named_parameters():
|
|
assert name in expected_params
|
|
# We cannot compare y_expected == y_pruned, as the 0 elements mess up the numerics
|
|
# Instead we check that the weights of the new LSTM are a subset of the weights of
|
|
# the old LSTM
|
|
assert rows_are_subset(param, expected_params[name])
|
|
del expected_params[name]
|
|
|
|
# assert we haven't deleted any keys
|
|
assert len(expected_params) == 0
|
|
|
|
def test_prune_lstm_linear_single_layer(self):
|
|
"""
|
|
Test fusion support for LSTM (single-layer) -> Linear
|
|
"""
|
|
model = LSTMLinearModel(
|
|
input_dim=8,
|
|
hidden_dim=8,
|
|
output_dim=8,
|
|
num_layers=1,
|
|
)
|
|
|
|
config = [
|
|
{"tensor_fqn": "lstm.weight_ih_l0"},
|
|
{"tensor_fqn": "lstm.weight_hh_l0"},
|
|
]
|
|
|
|
lstm_input = torch.ones((1, 8))
|
|
fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
|
|
fx_pruner.prepare(model, config)
|
|
fx_pruner.enable_mask_update = True
|
|
fx_pruner.step()
|
|
model.eval()
|
|
|
|
out_expected, lstm_out_expected = model(lstm_input)
|
|
pruned_model = fx_pruner.prune()
|
|
pruned_model.eval()
|
|
out_pruned, lstm_out_pruned = pruned_model(lstm_input)
|
|
_, c = lstm_out_expected.size()
|
|
|
|
# We cannot check that y_expected == y_pruned as usual because
|
|
# zeros vs. missing elements yield different numerical results.
|
|
# Instead that we check that the pruned elements are the first half of the results
|
|
# since we are using a BottomHalfLSTMPruner
|
|
assert torch.isclose(
|
|
lstm_out_expected[:, : c // 2], lstm_out_pruned, rtol=1e-05, atol=1e-07
|
|
).all()
|
|
# also check that output of linear is the same shape, this means we've resized
|
|
# linear columns correctly.
|
|
assert out_expected.shape == out_pruned.shape
|
|
|
|
def test_prune_lstm_layernorm_linear_multiple_layer(self):
|
|
"""
|
|
Test fusion support for LSTM(multi-layer) -> Linear
|
|
"""
|
|
model = LSTMLayerNormLinearModel(
|
|
input_dim=8,
|
|
output_dim=8,
|
|
hidden_dim=8,
|
|
num_layers=2,
|
|
)
|
|
|
|
config = [
|
|
{"tensor_fqn": "lstm.weight_ih_l0"},
|
|
{"tensor_fqn": "lstm.weight_hh_l0"},
|
|
{"tensor_fqn": "lstm.weight_ih_l1"},
|
|
{"tensor_fqn": "lstm.weight_hh_l1"},
|
|
]
|
|
|
|
lstm_input = torch.ones((1, 8))
|
|
fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
|
|
fx_pruner.prepare(model, config)
|
|
|
|
fx_pruner.enable_mask_update = True
|
|
fx_pruner.step()
|
|
|
|
model.eval()
|
|
_, _ = model(lstm_input)
|
|
pruned_model = fx_pruner.prune()
|
|
pruned_model.eval()
|
|
_, _ = pruned_model(lstm_input)
|
|
|
|
expected_params = dict(model.named_parameters())
|
|
for name, param in model.named_parameters():
|
|
assert name in expected_params
|
|
# We cannot compare y_expected == y_pruned, as the 0 elements mess up the numerics
|
|
# Instead we check that the weights of the new LSTM are a subset of the weights of
|
|
# the old LSTM
|
|
assert rows_are_subset(param, expected_params[name])
|
|
del expected_params[name]
|
|
|
|
# assert we haven't deleted any keys
|
|
assert len(expected_params) == 0
|
|
|
|
def test_prune_lstm_layernorm_linear_single_layer(self):
|
|
"""
|
|
Test fusion support for LSTM (single-layer) -> Linear
|
|
"""
|
|
model = LSTMLinearModel(
|
|
input_dim=8,
|
|
hidden_dim=8,
|
|
output_dim=8,
|
|
num_layers=1,
|
|
)
|
|
|
|
config = [
|
|
{"tensor_fqn": "lstm.weight_ih_l0"},
|
|
{"tensor_fqn": "lstm.weight_hh_l0"},
|
|
]
|
|
|
|
lstm_input = torch.ones((1, 8))
|
|
fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
|
|
fx_pruner.prepare(model, config)
|
|
fx_pruner.enable_mask_update = True
|
|
fx_pruner.step()
|
|
model.eval()
|
|
|
|
out_expected, lstm_out_expected = model(lstm_input)
|
|
pruned_model = fx_pruner.prune()
|
|
pruned_model.eval()
|
|
out_pruned, lstm_out_pruned = pruned_model(lstm_input)
|
|
_, c = lstm_out_expected.size()
|
|
|
|
# We cannot check that y_expected == y_pruned as usual because
|
|
# zeros vs. missing elements yield different numerical results.
|
|
# Instead that we check that the pruned elements are the first half of the results
|
|
# since we are using a BottomHalfLSTMPruner
|
|
assert torch.isclose(
|
|
lstm_out_expected[:, : c // 2], lstm_out_pruned, rtol=1e-05, atol=1e-07
|
|
).all()
|
|
# also check that output of linear is the same shape, this means we've resized
|
|
# linear columns correctly.
|
|
assert out_expected.shape == out_pruned.shape
|
|
|
|
|
|
class TestFPGMPruner(TestCase):
|
|
"""
|
|
Test case for the implementation of paper:
|
|
`Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration <https://arxiv.org/abs/1811.00250>`_.
|
|
"""
|
|
|
|
class SimpleConvFPGM(nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv2d1 = nn.Conv2d(
|
|
in_channels=1, out_channels=3, kernel_size=3, padding=1, bias=False
|
|
)
|
|
# Manually set the filter weights for demonstration purposes
|
|
"""
|
|
Three filters' weight are manually set to values 3.0, 2.0, and 0.1.
|
|
Different from the norm-based decision that prunes filter with value 0.1,
|
|
FPGM will prune the one with value 2.0.
|
|
"""
|
|
weights = torch.tensor([3.0, 2.0, 0.1]) # Weight weights for each filter
|
|
weights = weights[:, None, None, None] # broadcasting
|
|
self.conv2d1.weight.data.copy_(
|
|
torch.ones(self.conv2d1.weight.shape) * weights
|
|
)
|
|
|
|
# Second Convolutional Layer
|
|
self.conv2d2 = nn.Conv2d(
|
|
in_channels=3, out_channels=4, kernel_size=3, padding=1, bias=False
|
|
)
|
|
weights = torch.tensor([6.0, 7.0, 0.4, 0.5])
|
|
weights = weights[:, None, None, None]
|
|
self.conv2d2.weight.data.copy_(
|
|
torch.ones(self.conv2d2.weight.shape) * weights
|
|
)
|
|
|
|
def forward(self, x):
|
|
x = self.conv2d1(x)
|
|
x = self.conv2d2(x)
|
|
return x
|
|
|
|
def test_compute_distance(self, device="cpu"):
|
|
"""Test the distance computation function"""
|
|
model = TestFPGMPruner.SimpleConvFPGM().to(device)
|
|
pruner = FPGMPruner(0.3)
|
|
dist_conv1 = pruner._compute_distance(model.conv2d1.weight)
|
|
|
|
# compute the distance matrix using torch.cdist
|
|
flattened_filters = torch.Tensor(
|
|
[
|
|
[
|
|
3.0000,
|
|
3.0000,
|
|
3.0000,
|
|
3.0000,
|
|
3.0000,
|
|
3.0000,
|
|
3.0000,
|
|
3.0000,
|
|
3.0000,
|
|
],
|
|
[
|
|
2.0000,
|
|
2.0000,
|
|
2.0000,
|
|
2.0000,
|
|
2.0000,
|
|
2.0000,
|
|
2.0000,
|
|
2.0000,
|
|
2.0000,
|
|
],
|
|
[
|
|
0.1000,
|
|
0.1000,
|
|
0.1000,
|
|
0.1000,
|
|
0.1000,
|
|
0.1000,
|
|
0.1000,
|
|
0.1000,
|
|
0.1000,
|
|
],
|
|
]
|
|
)
|
|
|
|
"""
|
|
Expected distance matrix should have the following values:
|
|
[0.0000, 3.0000, 8.7000],
|
|
[3.0000, 0.0000, 5.7000],
|
|
[8.7000, 5.7000, 0.0000],
|
|
the distance should therefore be:
|
|
[11.7000, 8.7000, 14.4000]
|
|
"""
|
|
expected_dist_matrix_conv1 = torch.cdist(
|
|
flattened_filters, flattened_filters, p=2
|
|
)
|
|
expected_dist_conv1 = torch.sum(torch.abs(expected_dist_matrix_conv1), 1)
|
|
assert torch.isclose(
|
|
dist_conv1, expected_dist_conv1, rtol=1e-05, atol=1e-07
|
|
).all()
|
|
|
|
def _test_update_mask_on_single_layer(self, expected_conv1, device):
|
|
"""Test that pruning is conducted based on the pair-wise distance measurement instead of absolute norm value"""
|
|
# test pruning with one layer of conv2d
|
|
model = TestFPGMPruner.SimpleConvFPGM().to(device)
|
|
x = torch.ones((1, 1, 32, 32), device=device)
|
|
pruner = FPGMPruner(0.3)
|
|
config = [{"tensor_fqn": "conv2d1.weight"}]
|
|
pruner.prepare(model, config)
|
|
pruner.enable_mask_update = True
|
|
pruner.step()
|
|
assert (
|
|
pruner.groups[0]["module"].parametrizations.weight[0].mask[-1].item()
|
|
is not False
|
|
), "do not prune the least-norm filter"
|
|
|
|
# fusion step
|
|
pruned_model = pruner.prune()
|
|
|
|
pruned_y = pruned_model(x)
|
|
# assert shapes
|
|
expected_conv1 = expected_conv1.to(device)
|
|
assert pruned_y.shape == (1, 4, 32, 32)
|
|
assert pruned_model.conv2d1.weight.shape == expected_conv1.shape
|
|
assert pruned_model.conv2d2.weight.shape == (
|
|
4,
|
|
2,
|
|
3,
|
|
3,
|
|
), "conv2d2 should have input channel pruned"
|
|
# assert value
|
|
assert torch.isclose(
|
|
pruned_model.conv2d1.weight, expected_conv1, rtol=1e-05, atol=1e-07
|
|
).all()
|
|
|
|
def _test_update_mask_on_multiple_layer(
|
|
self, expected_conv1, expected_conv2, device
|
|
):
|
|
# the second setting
|
|
model = TestFPGMPruner.SimpleConvFPGM().to(device)
|
|
x = torch.ones((1, 1, 32, 32), device=device)
|
|
pruner = FPGMPruner(0.3)
|
|
config = [
|
|
{"tensor_fqn": "conv2d1.weight"},
|
|
{"tensor_fqn": "conv2d2.weight", "sparsity_level": 0.5},
|
|
]
|
|
pruner.prepare(model, config)
|
|
pruner.enable_mask_update = True
|
|
pruner.step()
|
|
# Get the masks for the two least-norm filters
|
|
mask1 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-1]
|
|
mask2 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-2]
|
|
# Check if either of the least-norm filters is not pruned
|
|
assert mask1.item() is not False or mask2.item() is not False, (
|
|
"Do not prune all least-norm filters"
|
|
)
|
|
|
|
# fusion step
|
|
pruned_model = pruner.prune()
|
|
pruned_y = pruned_model(x)
|
|
# assert shapes
|
|
expected_conv1 = expected_conv1.to(device)
|
|
expected_conv2 = expected_conv2.to(device)
|
|
assert pruned_y.shape == (1, 2, 32, 32)
|
|
assert pruned_model.conv2d1.weight.shape == expected_conv1.shape
|
|
assert pruned_model.conv2d2.weight.shape == expected_conv2.shape
|
|
# assert values
|
|
assert torch.isclose(
|
|
pruned_model.conv2d1.weight, expected_conv1, rtol=1e-05, atol=1e-07
|
|
).all()
|
|
assert torch.isclose(
|
|
pruned_model.conv2d2.weight, expected_conv2, rtol=1e-05, atol=1e-07
|
|
).all()
|
|
|
|
def test_update_mask(self):
|
|
weights = torch.tensor([3.0, 0.1])
|
|
expected_conv1 = torch.ones((2, 1, 3, 3)) * weights[:, None, None, None]
|
|
|
|
weights = torch.tensor([7.0, 0.4])
|
|
expected_conv2 = torch.ones((2, 2, 3, 3)) * weights[:, None, None, None]
|
|
|
|
for device in DEVICES:
|
|
self._test_update_mask_on_single_layer(expected_conv1, device)
|
|
self._test_update_mask_on_multiple_layer(
|
|
expected_conv1, expected_conv2, device
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
raise_on_run_directly("test/test_ao_sparsity.py")
|