# Owner(s): ["module: sparse"]
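"""Tests for structured (row/channel) pruning: BaseStructuredSparsifier subclasses,
SaliencyPruner, LSTMSaliencyPruner, and FPGMPruner."""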
import copy
import random
import torch
from torch import nn
from torch.ao.pruning._experimental.pruner import (
BaseStructuredSparsifier,
FakeStructuredSparsity,
FPGMPruner,
LSTMSaliencyPruner,
SaliencyPruner,
)
from torch.nn.utils import parametrize
from torch.testing._internal.common_pruning import (
Conv2dActivation,
Conv2dBias,
Conv2dPadBias,
Conv2dPool,
Conv2dPoolFlatten,
Conv2dPoolFlattenFunctional,
LinearActivation,
LinearActivationFunctional,
LinearBias,
LSTMLayerNormLinearModel,
LSTMLinearModel,
rows_are_subset,
SimpleConv2d,
SimpleLinear,
)
from torch.testing._internal.common_utils import (
raise_on_run_directly,
skipIfTorchDynamo,
TestCase,
)
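
# Devices to run each test on: CPU always, plus CUDA when available
# (the set collapses the duplicate CPU entry on CUDA-less machines).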
DEVICES = {
torch.device("cpu"),
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
}


class SimplePruner(BaseStructuredSparsifier):
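    """Trivial pruner that masks out row 1 of each configured tensor."""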
def update_mask(self, module, tensor_name, **kwargs):
getattr(module.parametrizations, tensor_name)[0].mask[1] = False


class ImplementedPruner(BaseStructuredSparsifier):
def update_mask(self, module, tensor_name, **kwargs):
"""Prunes 1/3 of the weight output channels, so resulting module has 33.3% pruning"""
num_rows = len(module.parametrizations[tensor_name][0].mask)
prune = random.sample(list(range(num_rows)), num_rows // 3)
module.parametrizations[tensor_name][0].mask[prune] = False


class BottomHalfLSTMPruner(BaseStructuredSparsifier):
"""
Pruner that will remove the bottom half of the rows.
This is primarily meant for testing purposes
"""
def update_mask(self, module, tensor_name, **kwargs):
for p in getattr(module.parametrizations, tensor_name):
if isinstance(p, FakeStructuredSparsity):
mask = p.mask
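                # LSTM weights stack the four gate blocks (input, forget, cell,
                # output) row-wise, so split the mask into per-gate quarters.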
masks = torch.split(mask, len(mask) // 4)
for small in masks:
num = len(small)
small[num // 2 :] = False
new_mask = torch.cat(masks)
mask.data = new_mask.data


class TestSaliencyPruner(TestCase):
def test_saliency_pruner_update_mask(self):
"""Test that we prune out the row with the lowest saliency (first row)"""
model = SimpleLinear()
with torch.no_grad():
model.linear1.weight = nn.Parameter(
torch.Tensor([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]])
)
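        # At sparsity_level=0.5, SaliencyPruner drops the two rows with the
        # lowest saliency (the all-1 and all-2 rows).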
pruning_config = [{"tensor_fqn": "linear1.weight", "sparsity_level": 0.5}]
pruner = SaliencyPruner({})
pruner.prepare(model, pruning_config)
pruner.enable_mask_update = True
pruner.step()
pruned_model = pruner.prune()
expected = torch.Tensor([[3, 3, 3, 3], [4, 4, 4, 4]])
pruned = pruned_model.linear1.weight
assert expected.shape == pruned.shape
assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()
def test_lstm_saliency_pruner_update_mask(self):
model = LSTMLinearModel(
input_dim=2,
hidden_dim=2,
output_dim=2,
num_layers=1,
)
manual_weights = torch.Tensor(
[[1, 1], [2, 2], [2, 2], [1, 1], [-1, -1], [-2, -2], [-2, -2], [-1, -1]]
)
with torch.no_grad():
model.lstm.weight_ih_l0 = nn.Parameter(manual_weights)
model.lstm.weight_hh_l0 = nn.Parameter(torch.Tensor(manual_weights))
model.lstm.bias_ih_l0 = nn.Parameter(manual_weights[:, 0])
model.lstm.bias_hh_l0 = nn.Parameter(manual_weights[:, 0])
config = [
{"tensor_fqn": "lstm.weight_ih_l0"},
{"tensor_fqn": "lstm.weight_hh_l0"},
]
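        # LSTMSaliencyPruner scores rows within each gate block, so the
        # low-magnitude [1, 1] / [-1, -1] rows are pruned from every gate.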
lstm_input = torch.ones((1, 2))
fx_pruner = LSTMSaliencyPruner({"sparsity_level": 0.5})
fx_pruner.prepare(model, config)
fx_pruner.enable_mask_update = True
fx_pruner.step()
model.eval()
pruned_model = fx_pruner.prune()
pruned_model.eval()
# make sure both models run
model(lstm_input)
pruned_model(lstm_input)
# make sure lowest saliency rows are pruned
expected = torch.Tensor([[2, 2], [2, 2], [-2, -2], [-2, -2]])
pruned = model.lstm.weight_ih_l0
assert expected.shape == pruned.shape
assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()
expected = torch.Tensor([[2], [2], [-2], [-2]])
pruned = model.lstm.weight_hh_l0
assert expected.shape == pruned.shape
assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()
expected = torch.Tensor([2, 2, -2, -2])
for pruned in [model.lstm.bias_ih_l0, model.lstm.bias_hh_l0]:
assert expected.shape == pruned.shape
assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()


class TestBaseStructuredSparsifier(TestCase):
def _check_pruner_prepared(self, model, pruner, device):
for config in pruner.groups:
module = config["module"]
assert module.weight.device.type == device.type
# Check mask exists
assert config["tensor_fqn"] in pruner.state
# Check parametrization exists and is correct
assert parametrize.is_parametrized(module)
assert hasattr(module, "parametrizations")
# Assume that this is the 1st/only parametrization
assert type(module.parametrizations.weight[0]) == FakeStructuredSparsity
def _check_pruner_valid_before_step(self, model, pruner, device):
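        """Before stepping, each configured module should carry a boolean mask on the right device."""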
for config in pruner.groups:
modules = []
if type(config["module"]) is tuple:
modules.extend(config["module"])
else:
module = config["module"]
modules.append(module)
for module in modules:
assert module.weight.device.type == device.type
assert module.parametrizations.weight[0].mask.dtype == torch.bool
def _check_pruner_valid_after_step(self, model, pruner, mask, device):
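        """After stepping, each mask should have exactly `mask` entries set to False."""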
for config in pruner.groups:
modules = []
if type(config["module"]) is tuple:
modules.extend(config["module"])
else:
module = config["module"]
modules.append(module)
for module in modules:
assert module.weight.device.type == device.type
total = module.parametrizations.weight[0].mask.numel()
assert (
module.parametrizations.weight[0].mask.count_nonzero()
== total - mask
)
def _test_constructor_on_device(self, model, device):
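        # The abstract base class cannot be constructed directly; subclasses
        # must implement update_mask.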
self.assertRaisesRegex(
TypeError,
"BaseStructuredSparsifier.*update_mask",
BaseStructuredSparsifier,
)
model1 = copy.deepcopy(model).to(device)
pruner = SimplePruner(None)
pruner.prepare(model1, None)
pruner.enable_mask_update = True
for g in pruner.groups:
module = g["module"]
assert module.weight.device.type == device.type
assert len(pruner.groups) == 5
pruner.step()
        # Can instantiate the pruner with defaults and prepare it with a config
model2 = copy.deepcopy(model).to(device)
pruner = SimplePruner({"test": 3})
pruner.prepare(model2, [{"tensor_fqn": "seq.0.weight"}])
assert len(pruner.groups) == 1
assert pruner.groups[0]["module_fqn"] == "seq.0"
assert "test" in pruner.groups[0]
assert pruner.groups[0]["test"] == 3
def test_constructor(self):
model = SimpleLinear()
for device in DEVICES:
self._test_constructor_on_device(model, torch.device(device))
def _test_prepare_linear_on_device(self, model, device):
model = copy.deepcopy(model).to(device)
x = torch.ones(128, 7, device=device)
pruner = SimplePruner(None)
pruner.prepare(model, None)
self._check_pruner_prepared(model, pruner, device)
assert model(x).shape == (128, 10)
def test_prepare_linear(self):
models = [
SimpleLinear(),
LinearBias(),
LinearActivation(),
LinearActivationFunctional(),
] # without and with bias
for device in DEVICES:
for model in models:
self._test_prepare_linear_on_device(model, torch.device(device))
def _test_prepare_conv2d_on_device(self, model, expected_shape, config, device):
x = torch.ones((1, 1, 28, 28), device=device)
pruner = SimplePruner(None)
pruner.prepare(model, config)
self._check_pruner_prepared(model, pruner, device)
assert model(x).shape == expected_shape
def test_prepare_conv2d(self):
models = [
SimpleConv2d(),
Conv2dBias(),
Conv2dActivation(),
Conv2dPadBias(),
Conv2dPool(),
]
shapes = [
(1, 52, 20, 20),
(1, 52, 18, 18),
(1, 52, 18, 18),
(1, 52, 24, 24),
(1, 52, 3, 3),
]
configs = [None, None, None, None, None]
for device in DEVICES:
for model, shape, config in zip(models, shapes, configs):
model = model.to(device)
self._test_prepare_conv2d_on_device(
model, shape, config, torch.device(device)
)
def _test_step_linear_on_device(self, model, device):
model = model.to(device)
pruner = SimplePruner(None)
pruner.prepare(model, None)
pruner.enable_mask_update = True
self._check_pruner_valid_before_step(model, pruner, device)
pruner.step()
self._check_pruner_valid_after_step(model, pruner, 1, device)
def test_step_linear(self):
models = [
SimpleLinear(),
LinearBias(),
LinearActivation(),
LinearActivationFunctional(),
]
for device in DEVICES:
for model in models:
self._test_step_linear_on_device(model, torch.device(device))
def _test_step_conv2d_on_device(self, model, expected_shape, config, device):
model = model.to(device)
x = torch.ones((1, 1, 28, 28), device=device)
pruner = SimplePruner(None)
pruner.prepare(model, config)
pruner.enable_mask_update = True
self._check_pruner_valid_before_step(model, pruner, device)
pruner.step()
self._check_pruner_valid_after_step(model, pruner, 1, device)
assert model(x).shape == expected_shape
@skipIfTorchDynamo("TorchDynamo fails with unknown reason")
def test_step_conv2d(self):
models = [
SimpleConv2d(),
Conv2dBias(),
Conv2dActivation(),
Conv2dPadBias(),
Conv2dPool(),
]
shapes = [
(1, 52, 20, 20),
(1, 52, 18, 18),
(1, 52, 18, 18),
(1, 52, 24, 24),
(1, 52, 3, 3),
]
configs = [None, None, None, None, None]
for device in DEVICES:
for model, shape, config in zip(models, shapes, configs):
self._test_step_conv2d_on_device(
model, shape, config, torch.device(device)
)
def _check_pruner_pruned(self, model, pruner, device):
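        """After prune(), parametrizations and masks should be removed from each module."""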
for config in pruner.groups:
module = config["module"]
assert not hasattr(module, "parametrizations")
assert not hasattr(module, "mask")
def _test_linear_on_device(
self, model, config, expected_shape, device, also_prune_bias
):
model = model.to(device)
model.eval()
num_original_params = sum(p.numel() for p in model.parameters())
x = torch.ones(128, 7, device=device)
pruner = ImplementedPruner({"prune_bias": also_prune_bias})
pruner.prepare(model, config)
pruner.enable_mask_update = True
pruner.step()
y_expected = model(x)
assert y_expected.shape == (128, 10)
self._check_pruner_prepared(model, pruner, device)
# Pruning step
pruned = pruner.prune()
y_pruned = pruned(x)
num_pruned_params = sum(p.numel() for p in pruned.parameters())
assert y_pruned.shape == expected_shape
self._check_pruner_pruned(model, pruner, device)
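        # The masked (parametrized) model and the physically pruned model
        # should agree numerically whenever their output shapes match.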
if y_pruned.shape == y_expected.shape:
assert torch.isclose(y_expected, y_pruned, rtol=1e-05, atol=1e-07).all()
assert num_pruned_params < num_original_params
def test_prune_linear_linear(self):
r"""test pruning linear-> linear modules"""
configs, shapes = [], []
configs.append(
[
{"tensor_fqn": "seq.0.weight"},
{"tensor_fqn": "seq.1.weight"},
{"tensor_fqn": "seq.2.weight"},
]
)
shapes.append((128, 10))
configs.append(
[
{"tensor_fqn": "seq.0.weight"},
{"tensor_fqn": "seq.1.weight"},
{"tensor_fqn": "seq.2.weight"},
{"tensor_fqn": "linear1.weight"},
]
)
shapes.append((128, 10))
configs.append(
[
{"tensor_fqn": "seq.0.weight"},
{"tensor_fqn": "seq.2.weight"},
]
)
shapes.append((128, 10))
for device in DEVICES:
for also_prune_bias in [True, False]:
for config, shape in zip(configs, shapes):
self._test_linear_on_device(
SimpleLinear(),
config,
shape,
torch.device(device),
also_prune_bias,
)
def test_prune_linear_bias_linear(self):
# linear(bias) -> linear(no bias)
configs, shapes = [], []
configs.append(
[
{"tensor_fqn": "seq.0.weight"},
{"tensor_fqn": "seq.1.weight"},
]
)
shapes.append((128, 10))
# linear(bias) -> linear(bias)
configs.append(
[
{"tensor_fqn": "seq.2.weight"},
{"tensor_fqn": "seq.3.weight"},
]
)
shapes.append((128, 10))
# linear(no bias) -> linear(bias)
configs.append(
[
{"tensor_fqn": "seq.0.weight"},
{"tensor_fqn": "seq.1.weight"},
{"tensor_fqn": "seq.2.weight"},
]
)
shapes.append((128, 10))
for device in DEVICES:
for also_prune_bias in [True, False]:
for config, shape in zip(configs, shapes):
self._test_linear_on_device(
LinearBias(),
config,
shape,
torch.device(device),
also_prune_bias,
)
def test_prune_linear_activation_linear(self):
config = [
{"tensor_fqn": "seq.0.weight"},
{"tensor_fqn": "seq.2.weight"},
{"tensor_fqn": "seq.4.weight"},
{"tensor_fqn": "linear1.weight"},
]
shape = (128, 10)
for device in DEVICES:
for also_prune_bias in [True, False]:
# test version with nn.Modules
self._test_linear_on_device(
LinearActivation(),
config,
shape,
torch.device(device),
also_prune_bias,
)
# test functional version
self._test_linear_on_device(
LinearActivationFunctional(),
config,
shape,
torch.device(device),
also_prune_bias,
)
def _test_conv2d_on_device(
self, model, config, x, expected_shape, device, also_prune_bias
):
model = model.to(device)
num_original_params = sum(p.numel() for p in model.parameters())
model.eval()
pruner = ImplementedPruner({"prune_bias": also_prune_bias})
pruner.prepare(model, config)
pruner.enable_mask_update = True
pruner.step()
y_expected = model(x)
assert y_expected.shape == expected_shape
self._check_pruner_prepared(model, pruner, device)
        # Pruning step
pruned = pruner.prune()
y_pruned = pruned(x)
num_pruned_params = sum(p.numel() for p in pruned.parameters())
assert y_pruned.shape == expected_shape
self._check_pruner_pruned(model, pruner, device)
if y_pruned.shape == y_expected.shape:
# TODO This rtol is a little high, need to double check if something specific is causing this to fail
assert torch.isclose(
y_expected,
y_pruned,
rtol=1e-3,
atol=1e-3,
).all(), f"fail for {type(model)}"
        # The only time these are equal is when all layers have padding and nothing can be pruned
assert num_pruned_params <= num_original_params
def test_prune_conv2d_conv2d(self):
configs, shapes = [], []
# all within sequential blocks
configs.append(
[
{"tensor_fqn": "seq.0.weight"},
]
)
shapes.append((1, 52, 20, 20))
# prune across sequential blocks
configs.append(
[
{"tensor_fqn": "seq.0.weight"},
{"tensor_fqn": "seq.1.weight"},
{"tensor_fqn": "conv2d1.weight"},
]
)
shapes.append((1, 52, 20, 20))
for device in DEVICES:
x = torch.ones((1, 1, 28, 28), device=device)
for also_prune_bias in [True, False]:
for config, shape in zip(configs, shapes):
self._test_conv2d_on_device(
SimpleConv2d(),
config,
x,
shape,
torch.device(device),
also_prune_bias,
)
def test_prune_conv2d_bias_conv2d(self):
# Conv2d with Bias and no Activation
configs, shapes = [], []
# conv2d(bias) -> conv2d(bias)
configs.append(
[
{"tensor_fqn": "seq.0.weight"},
{"tensor_fqn": "seq.1.weight"},
]
)
shapes.append((1, 52, 18, 18))
# conv2d(no bias) -> conv2d(bias)
configs.append(
[
{"tensor_fqn": "seq.0.weight"},
{"tensor_fqn": "seq.1.weight"},
{"tensor_fqn": "conv2d1.weight"},
]
)
shapes.append((1, 52, 18, 18))
# conv2d(bias) -> conv2d(no bias)
configs.append(
[
{"tensor_fqn": "seq.0.weight"},
{"tensor_fqn": "seq.1.weight"},
{"tensor_fqn": "seq.2.weight"},
]
)
shapes.append((1, 52, 18, 18))
for device in DEVICES:
x = torch.ones((1, 1, 28, 28), device=device)
for also_prune_bias in [True, False]:
for config, shape in zip(configs, shapes):
self._test_conv2d_on_device(
Conv2dBias(),
config,
x,
shape,
torch.device(device),
also_prune_bias,
)
def test_prune_conv2d_activation_conv2d(self):
# Conv2d with Activation and no Bias
configs, shapes = [], []
# conv2d(no bias) -> activation -> conv2d(no bias)
configs.append(
[
{"tensor_fqn": "seq.4.weight"},
]
)
shapes.append((1, 52, 18, 18))
# conv2d(bias) -> activation -> conv2d(bias)
configs.append(
[
{"tensor_fqn": "seq.0.weight"},
{"tensor_fqn": "seq.2.weight"},
]
)
shapes.append((1, 52, 18, 18))
# conv2d(bias) -> activation -> conv2d(no bias)
configs.append(
[
{"tensor_fqn": "seq.2.weight"},
{"tensor_fqn": "seq.4.weight"},
]
)
shapes.append((1, 52, 18, 18))
# conv2d(no bias) -> activation -> conv2d(bias)
configs.append(
[
{"tensor_fqn": "conv2d1.weight"},
]
)
shapes.append((1, 52, 18, 18))
for device in DEVICES:
x = torch.ones((1, 1, 28, 28), device=device)
for also_prune_bias in [True, False]:
for config, shape in zip(configs, shapes):
self._test_conv2d_on_device(
Conv2dActivation(),
config,
x,
shape,
torch.device(device),
also_prune_bias,
)
def test_prune_conv2d_padding_conv2d(self):
# Conv2d with Padded layers after Bias layers
configs, shapes = [], []
# conv(padded, bias) -> conv(padded, bias)
configs.append(
[
{"tensor_fqn": "seq.4.weight"},
]
)
shapes.append((1, 52, 24, 24))
# conv(no bias, no pad) -> conv(padded, bias)
configs.append(
[
{"tensor_fqn": "seq.2.weight"},
]
)
shapes.append((1, 52, 24, 24))
        # conv(padded, bias) -> conv(no bias, no pad)
configs.append(
[
{"tensor_fqn": "seq.0.weight"},
]
)
shapes.append((1, 52, 24, 24))
# conv(pad, bias) -> conv(no pad, bias)
configs.append(
[
{"tensor_fqn": "seq.6.weight"},
]
)
shapes.append((1, 52, 24, 24))
# conv(no pad, bias) -> conv(pad, bias)
configs.append(
[
{"tensor_fqn": "seq.8.weight"},
]
)
shapes.append((1, 52, 24, 24))
for device in DEVICES:
x = torch.ones((1, 1, 28, 28), device=device)
for also_prune_bias in [True, False]:
for config, shape in zip(configs, shapes):
self._test_conv2d_on_device(
Conv2dPadBias(),
config,
x,
shape,
torch.device(device),
also_prune_bias,
)
def test_prune_conv2d_pool_conv2d(self):
# Conv2d with Pooling layers
config = [
{"tensor_fqn": "seq.0.weight"},
{"tensor_fqn": "seq.3.weight"},
{"tensor_fqn": "conv2d1.weight"},
{"tensor_fqn": "conv2d2.weight"},
]
shape = (1, 52, 3, 3)
for device in DEVICES:
x = torch.ones((1, 1, 28, 28), device=device)
for also_prune_bias in [True, False]:
self._test_conv2d_on_device(
Conv2dPool(),
config,
x,
shape,
torch.device(device),
also_prune_bias,
)
@skipIfTorchDynamo("TorchDynamo fails with unknown reason")
def test_complex_conv2d(self):
"""Test fusion for models that contain Conv2d & Linear modules.
Currently supports: Conv2d-Pool2d-Flatten-Linear, Skip-add"""
config = [
{"tensor_fqn": "seq.0.weight"},
{"tensor_fqn": "seq.3.weight"},
{"tensor_fqn": "conv2d1.weight"},
{"tensor_fqn": "conv2d2.weight"},
]
shape = (1, 13)
for device in DEVICES:
x = torch.ones((1, 1, 28, 28), device=device)
for also_prune_bias in [True, False]:
self._test_conv2d_on_device(
Conv2dPoolFlattenFunctional(),
config,
x,
shape,
torch.device(device),
also_prune_bias,
)
self._test_conv2d_on_device(
Conv2dPoolFlatten(),
config,
x,
shape,
torch.device(device),
also_prune_bias,
)
def test_prune_lstm_linear_multiple_layer(self):
"""
        Test fusion support for LSTM (multi-layer) -> Linear
"""
model = LSTMLinearModel(
input_dim=8,
hidden_dim=8,
output_dim=8,
num_layers=2,
)
config = [
{"tensor_fqn": "lstm.weight_ih_l0"},
{"tensor_fqn": "lstm.weight_hh_l0"},
{"tensor_fqn": "lstm.weight_ih_l1"},
{"tensor_fqn": "lstm.weight_hh_l1"},
]
lstm_input = torch.ones((1, 8))
fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
fx_pruner.prepare(model, config)
fx_pruner.enable_mask_update = True
fx_pruner.step()
model.eval()
_, _ = model(lstm_input)
pruned_model = fx_pruner.prune()
pruned_model.eval()
_, _ = pruned_model(lstm_input)
expected_params = dict(model.named_parameters())
for name, param in model.named_parameters():
assert name in expected_params
# We cannot compare y_expected == y_pruned, as the 0 elements mess up the numerics
# Instead we check that the weights of the new LSTM are a subset of the weights of
# the old LSTM
assert rows_are_subset(param, expected_params[name])
del expected_params[name]
# assert we haven't deleted any keys
assert len(expected_params) == 0
def test_prune_lstm_linear_single_layer(self):
"""
Test fusion support for LSTM (single-layer) -> Linear
"""
model = LSTMLinearModel(
input_dim=8,
hidden_dim=8,
output_dim=8,
num_layers=1,
)
config = [
{"tensor_fqn": "lstm.weight_ih_l0"},
{"tensor_fqn": "lstm.weight_hh_l0"},
]
lstm_input = torch.ones((1, 8))
fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
fx_pruner.prepare(model, config)
fx_pruner.enable_mask_update = True
fx_pruner.step()
model.eval()
out_expected, lstm_out_expected = model(lstm_input)
pruned_model = fx_pruner.prune()
pruned_model.eval()
out_pruned, lstm_out_pruned = pruned_model(lstm_input)
_, c = lstm_out_expected.size()
        # We cannot check that y_expected == y_pruned as usual because
        # zeros vs. missing elements yield different numerical results.
        # Instead we check that the pruned output matches the first half of the
        # unpruned output, since we are using a BottomHalfLSTMPruner.
assert torch.isclose(
lstm_out_expected[:, : c // 2], lstm_out_pruned, rtol=1e-05, atol=1e-07
).all()
        # Also check that the linear output has the same shape; this means we've
        # resized the linear layer's input columns correctly.
assert out_expected.shape == out_pruned.shape
def test_prune_lstm_layernorm_linear_multiple_layer(self):
"""
        Test fusion support for LSTM (multi-layer) -> LayerNorm -> Linear
"""
model = LSTMLayerNormLinearModel(
input_dim=8,
output_dim=8,
hidden_dim=8,
num_layers=2,
)
config = [
{"tensor_fqn": "lstm.weight_ih_l0"},
{"tensor_fqn": "lstm.weight_hh_l0"},
{"tensor_fqn": "lstm.weight_ih_l1"},
{"tensor_fqn": "lstm.weight_hh_l1"},
]
lstm_input = torch.ones((1, 8))
fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
fx_pruner.prepare(model, config)
fx_pruner.enable_mask_update = True
fx_pruner.step()
model.eval()
_, _ = model(lstm_input)
pruned_model = fx_pruner.prune()
pruned_model.eval()
_, _ = pruned_model(lstm_input)
expected_params = dict(model.named_parameters())
for name, param in model.named_parameters():
assert name in expected_params
# We cannot compare y_expected == y_pruned, as the 0 elements mess up the numerics
# Instead we check that the weights of the new LSTM are a subset of the weights of
# the old LSTM
assert rows_are_subset(param, expected_params[name])
del expected_params[name]
# assert we haven't deleted any keys
assert len(expected_params) == 0
def test_prune_lstm_layernorm_linear_single_layer(self):
"""
Test fusion support for LSTM (single-layer) -> Linear
"""
model = LSTMLinearModel(
input_dim=8,
hidden_dim=8,
output_dim=8,
num_layers=1,
)
config = [
{"tensor_fqn": "lstm.weight_ih_l0"},
{"tensor_fqn": "lstm.weight_hh_l0"},
]
lstm_input = torch.ones((1, 8))
fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
fx_pruner.prepare(model, config)
fx_pruner.enable_mask_update = True
fx_pruner.step()
model.eval()
out_expected, lstm_out_expected = model(lstm_input)
pruned_model = fx_pruner.prune()
pruned_model.eval()
out_pruned, lstm_out_pruned = pruned_model(lstm_input)
_, c = lstm_out_expected.size()
        # We cannot check that y_expected == y_pruned as usual because
        # zeros vs. missing elements yield different numerical results.
        # Instead we check that the pruned output matches the first half of the
        # unpruned output, since we are using a BottomHalfLSTMPruner.
assert torch.isclose(
lstm_out_expected[:, : c // 2], lstm_out_pruned, rtol=1e-05, atol=1e-07
).all()
        # Also check that the linear output has the same shape; this means we've
        # resized the linear layer's input columns correctly.
assert out_expected.shape == out_pruned.shape


class TestFPGMPruner(TestCase):
"""
Test case for the implementation of paper:
`Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration <https://arxiv.org/abs/1811.00250>`_.
"""
class SimpleConvFPGM(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv2d1 = nn.Conv2d(
in_channels=1, out_channels=3, kernel_size=3, padding=1, bias=False
)
# Manually set the filter weights for demonstration purposes
"""
            The three filters' weights are manually set to 3.0, 2.0, and 0.1.
            Unlike a norm-based criterion, which would prune the filter with
            value 0.1, FPGM will prune the one with value 2.0.
"""
            weights = torch.tensor([3.0, 2.0, 0.1])  # weight value for each filter
weights = weights[:, None, None, None] # broadcasting
self.conv2d1.weight.data.copy_(
torch.ones(self.conv2d1.weight.shape) * weights
)
# Second Convolutional Layer
self.conv2d2 = nn.Conv2d(
in_channels=3, out_channels=4, kernel_size=3, padding=1, bias=False
)
weights = torch.tensor([6.0, 7.0, 0.4, 0.5])
weights = weights[:, None, None, None]
self.conv2d2.weight.data.copy_(
torch.ones(self.conv2d2.weight.shape) * weights
)
def forward(self, x):
x = self.conv2d1(x)
x = self.conv2d2(x)
return x
def test_compute_distance(self, device="cpu"):
"""Test the distance computation function"""
model = TestFPGMPruner.SimpleConvFPGM().to(device)
pruner = FPGMPruner(0.3)
dist_conv1 = pruner._compute_distance(model.conv2d1.weight)
# compute the distance matrix using torch.cdist
flattened_filters = torch.Tensor(
[
[
3.0000,
3.0000,
3.0000,
3.0000,
3.0000,
3.0000,
3.0000,
3.0000,
3.0000,
],
[
2.0000,
2.0000,
2.0000,
2.0000,
2.0000,
2.0000,
2.0000,
2.0000,
2.0000,
],
[
0.1000,
0.1000,
0.1000,
0.1000,
0.1000,
0.1000,
0.1000,
0.1000,
0.1000,
],
]
)
"""
Expected distance matrix should have the following values:
[0.0000, 3.0000, 8.7000],
[3.0000, 0.0000, 5.7000],
[8.7000, 5.7000, 0.0000],
        the per-filter distance (row-wise sum) should therefore be:
[11.7000, 8.7000, 14.4000]
"""
expected_dist_matrix_conv1 = torch.cdist(
flattened_filters, flattened_filters, p=2
)
expected_dist_conv1 = torch.sum(torch.abs(expected_dist_matrix_conv1), 1)
assert torch.isclose(
dist_conv1, expected_dist_conv1, rtol=1e-05, atol=1e-07
).all()
def _test_update_mask_on_single_layer(self, expected_conv1, device):
"""Test that pruning is conducted based on the pair-wise distance measurement instead of absolute norm value"""
# test pruning with one layer of conv2d
model = TestFPGMPruner.SimpleConvFPGM().to(device)
x = torch.ones((1, 1, 32, 32), device=device)
pruner = FPGMPruner(0.3)
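        # With a sparsity level of 0.3, one of conv2d1's three filters is
        # pruned; FPGM should select the geometric-median filter (2.0),
        # not the least-norm filter (0.1).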
config = [{"tensor_fqn": "conv2d1.weight"}]
pruner.prepare(model, config)
pruner.enable_mask_update = True
pruner.step()
assert (
pruner.groups[0]["module"].parametrizations.weight[0].mask[-1].item()
is not False
), "do not prune the least-norm filter"
        # pruning step
pruned_model = pruner.prune()
pruned_y = pruned_model(x)
# assert shapes
expected_conv1 = expected_conv1.to(device)
assert pruned_y.shape == (1, 4, 32, 32)
assert pruned_model.conv2d1.weight.shape == expected_conv1.shape
assert pruned_model.conv2d2.weight.shape == (
4,
2,
3,
3,
), "conv2d2 should have input channel pruned"
# assert value
assert torch.isclose(
pruned_model.conv2d1.weight, expected_conv1, rtol=1e-05, atol=1e-07
).all()
def _test_update_mask_on_multiple_layer(
self, expected_conv1, expected_conv2, device
):
# the second setting
model = TestFPGMPruner.SimpleConvFPGM().to(device)
x = torch.ones((1, 1, 32, 32), device=device)
pruner = FPGMPruner(0.3)
config = [
{"tensor_fqn": "conv2d1.weight"},
{"tensor_fqn": "conv2d2.weight", "sparsity_level": 0.5},
]
pruner.prepare(model, config)
pruner.enable_mask_update = True
pruner.step()
# Get the masks for the two least-norm filters
mask1 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-1]
mask2 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-2]
# Check if either of the least-norm filters is not pruned
assert mask1.item() is not False or mask2.item() is not False, (
"Do not prune all least-norm filters"
)
        # pruning step
pruned_model = pruner.prune()
pruned_y = pruned_model(x)
# assert shapes
expected_conv1 = expected_conv1.to(device)
expected_conv2 = expected_conv2.to(device)
assert pruned_y.shape == (1, 2, 32, 32)
assert pruned_model.conv2d1.weight.shape == expected_conv1.shape
assert pruned_model.conv2d2.weight.shape == expected_conv2.shape
# assert values
assert torch.isclose(
pruned_model.conv2d1.weight, expected_conv1, rtol=1e-05, atol=1e-07
).all()
assert torch.isclose(
pruned_model.conv2d2.weight, expected_conv2, rtol=1e-05, atol=1e-07
).all()
def test_update_mask(self):
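        # After FPGM pruning, conv2d1 keeps the 3.0 and 0.1 filters and
        # conv2d2 keeps the 7.0 and 0.4 filters.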
weights = torch.tensor([3.0, 0.1])
expected_conv1 = torch.ones((2, 1, 3, 3)) * weights[:, None, None, None]
weights = torch.tensor([7.0, 0.4])
expected_conv2 = torch.ones((2, 2, 3, 3)) * weights[:, None, None, None]
for device in DEVICES:
self._test_update_mask_on_single_layer(expected_conv1, device)
self._test_update_mask_on_multiple_layer(
expected_conv1, expected_conv2, device
)


if __name__ == "__main__":
raise_on_run_directly("test/test_ao_sparsity.py")